Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 #ifndef LARANK_H
00050 #define LARANK_H
00051
00052 #include <ctime>
00053 #include <vector>
00054 #include <algorithm>
00055 #include <sys/time.h>
00056 #include <set>
00057 #include <map>
00058 #define STDEXT_NAMESPACE __gnu_cxx
00059 #define std_hash_map std::map
00060 #define std_hash_set std::set
00061
00062 #include <shogun/io/SGIO.h>
00063 #include <shogun/kernel/Kernel.h>
00064 #include <shogun/multiclass/MulticlassSVM.h>
00065
00066 namespace shogun
00067 {
00068 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00069 struct larank_kcache_s;
00070 typedef struct larank_kcache_s larank_kcache_t;
00071 struct larank_kcache_s
00072 {
00073 CKernel* func;
00074 larank_kcache_t *prevbuddy;
00075 larank_kcache_t *nextbuddy;
00076 int64_t maxsize;
00077 int64_t cursize;
00078 int32_t l;
00079 int32_t *i2r;
00080 int32_t *r2i;
00081 int32_t maxrowlen;
00082
00083 int32_t *rsize;
00084 float32_t *rdiag;
00085 float32_t **rdata;
00086 int32_t *rnext;
00087 int32_t *rprev;
00088 int32_t *qnext;
00089 int32_t *qprev;
00090 };
00091
00092
00093
00094
00095
00096 class LaRankOutput
00097 {
00098 public:
00099 LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
00100 {
00101 }
00102 virtual ~LaRankOutput ()
00103 {
00104 destroy();
00105 }
00106
00107
00108 void initialize (CKernel* kfunc, int64_t cache);
00109
00110
00111 void destroy ();
00112
00113
00114 float64_t computeScore (int32_t x_id);
00115
00116
00117 float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
00118
00119
00120 void update (int32_t x_id, float64_t lambda, float64_t gp);
00121
00122
00123
00124
00125 void set_kernel_buddy (larank_kcache_t * bud);
00126
00127
00128 int32_t cleanup ();
00129
00130
00131
00132
00133 inline larank_kcache_t *getKernel () const
00134 {
00135 return kernel;
00136 }
00137
00138 inline int32_t get_l () const
00139 {
00140 return l;
00141 }
00142
00143
00144 float64_t getW2 ();
00145
00146
00147 float64_t getKii (int32_t x_id);
00148
00149
00150 float64_t getBeta (int32_t x_id);
00151
00152
00153 inline float32_t* getBetas () const
00154 {
00155 return beta;
00156 }
00157
00158
00159 float64_t getGradient (int32_t x_id);
00160
00161
00162 bool isSupportVector (int32_t x_id) const;
00163
00164
00165 int32_t getSV (float32_t* &sv) const;
00166
00167 private:
00168
00169
00170 float32_t* beta;
00171 float32_t* g;
00172 larank_kcache_t *kernel;
00173 int32_t l;
00174 };
00175
00176
00177
00178
/* A training pattern: the index of an example (x_id) and its label (y).
 *
 * A negative x_id marks the pattern as cleared / non-existent.
 */
class LaRankPattern
{
public:
    /* Pattern for example x_index with label `label`. */
    LaRankPattern (int32_t x_index, int32_t label)
        : x_id (x_index), y (label) {}

    /* Default pattern: x_id == 0 and y == 0.
     * y is initialized explicitly so that copying or reading a
     * default-constructed pattern never touches an indeterminate value.
     * Note that x_id == 0 means exists() is true for a default pattern. */
    LaRankPattern ()
        : x_id (0), y (0) {}

    /* True unless the pattern has been cleared (x_id < 0). */
    bool exists () const
    {
        return x_id >= 0;
    }

    /* Marks the pattern as non-existent; y is left untouched. */
    void clear ()
    {
        x_id = -1;
    }

    int32_t x_id;   // example index, or -1 after clear()
    int32_t y;      // label
};
00200
00201
00202
00203
00204 class LaRankPatterns
00205 {
00206 public:
00207 LaRankPatterns () {}
00208 ~LaRankPatterns () {}
00209
00210 void insert (const LaRankPattern & pattern)
00211 {
00212 if (!isPattern (pattern.x_id))
00213 {
00214 if (freeidx.size ())
00215 {
00216 std_hash_set < uint32_t >::iterator it = freeidx.begin ();
00217 patterns[*it] = pattern;
00218 x_id2rank[pattern.x_id] = *it;
00219 freeidx.erase (it);
00220 }
00221 else
00222 {
00223 patterns.push_back (pattern);
00224 x_id2rank[pattern.x_id] = patterns.size () - 1;
00225 }
00226 }
00227 else
00228 {
00229 int32_t rank = getPatternRank (pattern.x_id);
00230 patterns[rank] = pattern;
00231 }
00232 }
00233
00234 void remove (uint32_t i)
00235 {
00236 x_id2rank[patterns[i].x_id] = 0;
00237 patterns[i].clear ();
00238 freeidx.insert (i);
00239 }
00240
00241 bool empty () const
00242 {
00243 return patterns.size () == freeidx.size ();
00244 }
00245
00246 uint32_t size () const
00247 {
00248 return patterns.size () - freeidx.size ();
00249 }
00250
00251 LaRankPattern & sample ()
00252 {
00253 ASSERT (!empty ());
00254 while (true)
00255 {
00256 uint32_t r = CMath::random(uint32_t(0), uint32_t(patterns.size ()-1));
00257 if (patterns[r].exists ())
00258 return patterns[r];
00259 }
00260 return patterns[0];
00261 }
00262
00263 uint32_t getPatternRank (int32_t x_id)
00264 {
00265 return x_id2rank[x_id];
00266 }
00267
00268 bool isPattern (int32_t x_id)
00269 {
00270 return x_id2rank[x_id] != 0;
00271 }
00272
00273 LaRankPattern & getPattern (int32_t x_id)
00274 {
00275 uint32_t rank = x_id2rank[x_id];
00276 return patterns[rank];
00277 }
00278
00279 uint32_t maxcount () const
00280 {
00281 return patterns.size ();
00282 }
00283
00284 LaRankPattern & operator [] (uint32_t i)
00285 {
00286 return patterns[i];
00287 }
00288
00289 const LaRankPattern & operator [] (uint32_t i) const
00290 {
00291 return patterns[i];
00292 }
00293
00294 private:
00295 std_hash_set < uint32_t >freeidx;
00296 std::vector < LaRankPattern > patterns;
00297 std_hash_map < int32_t, uint32_t >x_id2rank;
00298 };
00299
00300
00301 #endif // DOXYGEN_SHOULD_SKIP_THIS
00302
00303
00307 class CLaRank: public CMulticlassSVM
00308 {
00309 public:
00310 CLaRank ();
00311
00318 CLaRank(float64_t C, CKernel* k, CLabels* lab);
00319
00320 virtual ~CLaRank ();
00321
00322
00323
00328 virtual int32_t add (int32_t x_id, int32_t yi);
00329
00330
00334 virtual int32_t predict (int32_t x_id);
00335
00337 virtual void destroy ();
00338
00339
00341 virtual float64_t computeGap ();
00342
00343
00345 virtual uint32_t getNumOutputs () const;
00346
00347
00349 int32_t getNSV ();
00350
00351
00353 float64_t computeW2 ();
00354
00355
00357 float64_t getDual ();
00358
00363 virtual EMachineType get_classifier_type() { return CT_LARANK; }
00364
00366 virtual const char* get_name() const { return "LaRank"; }
00367
00371 void set_batch_mode(bool enable) { batch_mode=enable; };
00373 bool get_batch_mode() { return batch_mode; };
00377 void set_tau(float64_t t) { tau=t; };
00381 float64_t get_tau() { return tau; };
00382
00383 protected:
00385 bool train_machine(CFeatures* data);
00386
00387 private:
00388
00389
00390
00391
00392
00394 typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;
00395
00397 outputhash_t outputs;
00398
00399 LaRankOutput *getOutput (int32_t index);
00400
00401
00402 LaRankPatterns patterns;
00403
00404
00405 int32_t nb_seen_examples;
00406 int32_t nb_removed;
00407
00408
00409 int32_t n_pro;
00410 int32_t n_rep;
00411 int32_t n_opt;
00412
00413
00414 float64_t w_pro;
00415 float64_t w_rep;
00416 float64_t w_opt;
00417
00418 int32_t y0;
00419 float64_t dual;
00420
00421 struct outputgradient_t
00422 {
00423 outputgradient_t (int32_t result_output, float64_t result_gradient)
00424 : output (result_output), gradient (result_gradient) {}
00425 outputgradient_t ()
00426 : output (0), gradient (0) {}
00427
00428 int32_t output;
00429 float64_t gradient;
00430
00431 bool operator < (const outputgradient_t & og) const
00432 {
00433 return gradient > og.gradient;
00434 }
00435 };
00436
00437
00438 enum process_type
00439 {
00440 processNew,
00441 processOld,
00442 processOptimize
00443 };
00444
00445 struct process_return_t
00446 {
00447 process_return_t (float64_t dual, int32_t yprediction)
00448 : dual_increase (dual), ypred (yprediction) {}
00449 process_return_t () {}
00450 float64_t dual_increase;
00451 int32_t ypred;
00452 };
00453
00454
00455 process_return_t process (const LaRankPattern & pattern, process_type ptype);
00456
00457
00458 float64_t reprocess ();
00459
00460
00461 float64_t optimize ();
00462
00463
00464 uint32_t cleanup ();
00465
00466 protected:
00467
00469 std_hash_set < int32_t >classes;
00470
00472 inline uint32_t class_count () const
00473 {
00474 return classes.size ();
00475 }
00476
00478 float64_t tau;
00479
00481 int32_t nb_train;
00483 int64_t cache;
00485 bool batch_mode;
00486
00488 int32_t step;
00489 };
00490 }
00491 #endif // LARANK_H