Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 #ifndef LARANK_H
00050 #define LARANK_H
00051
00052 #include <ctime>
00053 #include <vector>
00054 #include <algorithm>
00055 #include <sys/time.h>
00056 #include <ext/hash_map>
00057 #include <ext/hash_set>
00058
00059 #define STDEXT_NAMESPACE __gnu_cxx
00060 #define std_hash_map STDEXT_NAMESPACE::hash_map
00061 #define std_hash_set STDEXT_NAMESPACE::hash_set
00062
00063 #include "lib/io.h"
00064 #include "kernel/Kernel.h"
00065 #include "classifier/svm/MultiClassSVM.h"
00066
00067 namespace shogun
00068 {
00069 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00070 struct larank_kcache_s;
00071 typedef struct larank_kcache_s larank_kcache_t;
00072 struct larank_kcache_s
00073 {
00074 CKernel* func;
00075 larank_kcache_t *prevbuddy;
00076 larank_kcache_t *nextbuddy;
00077 int64_t maxsize;
00078 int64_t cursize;
00079 int32_t l;
00080 int32_t *i2r;
00081 int32_t *r2i;
00082 int32_t maxrowlen;
00083
00084 int32_t *rsize;
00085 float32_t *rdiag;
00086 float32_t **rdata;
00087 int32_t *rnext;
00088 int32_t *rprev;
00089 int32_t *qnext;
00090 int32_t *qprev;
00091 };
00092
00093
00094
00095
00096
00097 class LaRankOutput
00098 {
00099 public:
00100 LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
00101 {
00102 }
00103 virtual ~LaRankOutput ()
00104 {
00105 destroy();
00106 }
00107
00108
00109 void initialize (CKernel* kfunc, int64_t cache);
00110
00111
00112 void destroy ();
00113
00114
00115 float64_t computeScore (int32_t x_id);
00116
00117
00118 float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
00119
00120
00121 void update (int32_t x_id, float64_t lambda, float64_t gp);
00122
00123
00124
00125
00126 void set_kernel_buddy (larank_kcache_t * bud);
00127
00128
00129 int32_t cleanup ();
00130
00131
00132
00133
00134 inline larank_kcache_t *getKernel () const
00135 {
00136 return kernel;
00137 }
00138
00139 inline int32_t get_l () const
00140 {
00141 return l;
00142 }
00143
00144
00145 float64_t getW2 ();
00146
00147
00148 float64_t getKii (int32_t x_id);
00149
00150
00151 float64_t getBeta (int32_t x_id);
00152
00153
00154 inline float32_t* getBetas () const
00155 {
00156 return beta;
00157 }
00158
00159
00160 float64_t getGradient (int32_t x_id);
00161
00162
00163 bool isSupportVector (int32_t x_id) const;
00164
00165
00166 int32_t getSV (float32_t* &sv) const;
00167
00168 private:
00169
00170
00171 float32_t* beta;
00172 float32_t* g;
00173 larank_kcache_t *kernel;
00174 int32_t l;
00175 };
00176
00177
00178
00179
/** A single training pattern: an example index paired with a label.
 *
 * A negative x_id marks the pattern as cleared; exists() reports exactly
 * that condition.
 */
class LaRankPattern
{
	public:
		/** Build a pattern from an example index and its label.
		 *
		 * @param x_index example index stored in x_id
		 * @param label label stored in y
		 */
		LaRankPattern (int32_t x_index, int32_t label)
			: x_id (x_index), y (label) {}

		/** Default pattern: x_id == 0 (so exists() is true) and y == 0.
		 *
		 * y is initialised explicitly so that copying or reading a
		 * default-constructed pattern never touches an indeterminate value.
		 */
		LaRankPattern ()
			: x_id (0), y (0) {}

		/** @return true while the pattern has not been cleared (x_id >= 0) */
		bool exists () const
		{
			return x_id >= 0;
		}

		/** Mark the pattern as cleared by setting x_id to -1; y is kept. */
		void clear ()
		{
			x_id = -1;
		}

		/** example index, -1 once cleared */
		int32_t x_id;
		/** label */
		int32_t y;
};
00201
00202
00203
00204
00205 class LaRankPatterns
00206 {
00207 public:
00208 LaRankPatterns () {}
00209 ~LaRankPatterns () {}
00210
00211 void insert (const LaRankPattern & pattern)
00212 {
00213 if (!isPattern (pattern.x_id))
00214 {
00215 if (freeidx.size ())
00216 {
00217 std_hash_set < uint32_t >::iterator it = freeidx.begin ();
00218 patterns[*it] = pattern;
00219 x_id2rank[pattern.x_id] = *it;
00220 freeidx.erase (it);
00221 }
00222 else
00223 {
00224 patterns.push_back (pattern);
00225 x_id2rank[pattern.x_id] = patterns.size () - 1;
00226 }
00227 }
00228 else
00229 {
00230 int32_t rank = getPatternRank (pattern.x_id);
00231 patterns[rank] = pattern;
00232 }
00233 }
00234
00235 void remove (uint32_t i)
00236 {
00237 x_id2rank[patterns[i].x_id] = 0;
00238 patterns[i].clear ();
00239 freeidx.insert (i);
00240 }
00241
00242 bool empty () const
00243 {
00244 return patterns.size () == freeidx.size ();
00245 }
00246
00247 uint32_t size () const
00248 {
00249 return patterns.size () - freeidx.size ();
00250 }
00251
00252 LaRankPattern & sample ()
00253 {
00254 ASSERT (!empty ());
00255 while (true)
00256 {
00257 uint32_t r = CMath::random(0, patterns.size ()-1);
00258 if (patterns[r].exists ())
00259 return patterns[r];
00260 }
00261 return patterns[0];
00262 }
00263
00264 uint32_t getPatternRank (int32_t x_id)
00265 {
00266 return x_id2rank[x_id];
00267 }
00268
00269 bool isPattern (int32_t x_id)
00270 {
00271 return x_id2rank[x_id] != 0;
00272 }
00273
00274 LaRankPattern & getPattern (int32_t x_id)
00275 {
00276 uint32_t rank = x_id2rank[x_id];
00277 return patterns[rank];
00278 }
00279
00280 uint32_t maxcount () const
00281 {
00282 return patterns.size ();
00283 }
00284
00285 LaRankPattern & operator [] (uint32_t i)
00286 {
00287 return patterns[i];
00288 }
00289
00290 const LaRankPattern & operator [] (uint32_t i) const
00291 {
00292 return patterns[i];
00293 }
00294
00295 private:
00296 std_hash_set < uint32_t >freeidx;
00297 std::vector < LaRankPattern > patterns;
00298 std_hash_map < int32_t, uint32_t >x_id2rank;
00299 };
00300
00301
00302 #endif // DOXYGEN_SHOULD_SKIP_THIS
00303
00304
00305
00306
00307
00308 class CLaRank: public CMultiClassSVM
00309 {
00310 public:
00311 CLaRank ();
00312
00319 CLaRank(float64_t C, CKernel* k, CLabels* lab);
00320
00321 virtual ~CLaRank ();
00322
00323 bool train(CFeatures* data);
00324
00325
00326
00327
00328 virtual int32_t add (int32_t x_id, int32_t yi);
00329
00330
00331 virtual int32_t predict (int32_t x_id);
00332
00333 virtual void destroy ();
00334
00335
00336 virtual float64_t computeGap ();
00337
00338
00339 virtual uint32_t getNumOutputs () const;
00340
00341
00342 int32_t getNSV ();
00343
00344
00345 float64_t computeW2 ();
00346
00347
00348 float64_t getDual ();
00349
00354 virtual inline EClassifierType get_classifier_type() { return CT_LARANK; }
00355
00357 inline virtual const char* get_name() const { return "LaRank"; }
00358
00359 void set_batch_mode(bool enable) { batch_mode=enable; };
00360 bool get_batch_mode() { return batch_mode; };
00361 void set_tau(float64_t t) { tau=t; };
00362 float64_t get_tau() { return tau; };
00363
00364
00365 private:
00366
00367
00368
00369
00370
00371 typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;
00372
00373
00374 outputhash_t outputs;
00375 LaRankOutput *getOutput (int32_t index);
00376
00377
00378 LaRankPatterns patterns;
00379
00380
00381 int32_t nb_seen_examples;
00382 int32_t nb_removed;
00383
00384
00385 int32_t n_pro;
00386 int32_t n_rep;
00387 int32_t n_opt;
00388
00389
00390 float64_t w_pro;
00391 float64_t w_rep;
00392 float64_t w_opt;
00393
00394 int32_t y0;
00395 float64_t dual;
00396
00397 struct outputgradient_t
00398 {
00399 outputgradient_t (int32_t result_output, float64_t result_gradient)
00400 : output (result_output), gradient (result_gradient) {}
00401 outputgradient_t ()
00402 : output (0), gradient (0) {}
00403
00404 int32_t output;
00405 float64_t gradient;
00406
00407 bool operator < (const outputgradient_t & og) const
00408 {
00409 return gradient > og.gradient;
00410 }
00411 };
00412
00413
00414 enum process_type
00415 {
00416 processNew,
00417 processOld,
00418 processOptimize
00419 };
00420
00421 struct process_return_t
00422 {
00423 process_return_t (float64_t dual, int32_t yprediction)
00424 : dual_increase (dual), ypred (yprediction) {}
00425 process_return_t () {}
00426 float64_t dual_increase;
00427 int32_t ypred;
00428 };
00429
00430
00431 process_return_t process (const LaRankPattern & pattern, process_type ptype);
00432
00433
00434 float64_t reprocess ();
00435
00436
00437 float64_t optimize ();
00438
00439
00440 uint32_t cleanup ();
00441
00442 protected:
00443
00444 std_hash_set < int32_t >classes;
00445
00446 inline uint32_t class_count () const
00447 {
00448 return classes.size ();
00449 }
00450
00451 float64_t tau;
00452 int32_t nb_train;
00453 int64_t cache;
00454
00455 bool batch_mode;
00456
00457
00458 int32_t step;
00459 };
00460 }
00461 #endif // LARANK_H