SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
LaRank.h
浏览该文件的文档.
1 // -*- C++ -*-
2 // Main functions of the LaRank algorithm for solving Multiclass SVM
3 // Copyright (C) 2008- Antoine Bordes
4 // Shogun specific adjustments (w) 2009 Soeren Sonnenburg
5 
6 // This library is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 2.1 of the License, or (at your option) any later version.
10 //
11 // This program is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
15 //
16 // You should have received a copy of the GNU Lesser General Public
17 // License along with this library; if not, write to the Free Software
18 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 //
20 /***********************************************************************
21  *
22  * LUSH Lisp Universal Shell
23  * Copyright (C) 2002 Leon Bottou, Yann Le Cun, AT&T Corp, NECI.
24  * Includes parts of TL3:
25  * Copyright (C) 1987-1999 Leon Bottou and Neuristique.
26  * Includes selected parts of SN3.2:
27  * Copyright (C) 1991-2001 AT&T Corp.
28  *
29  * This program is free software; you can redistribute it and/or modify
30  * it under the terms of the GNU General Public License as published by
31  * the Free Software Foundation; either version 2 of the License, or
32  * (at your option) any later version.
33  *
34  * This program is distributed in the hope that it will be useful,
35  * but WITHOUT ANY WARRANTY; without even the implied warranty of
36  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
37  * GNU General Public License for more details.
38  *
39  * You should have received a copy of the GNU General Public License
40  * along with this program; if not, write to the Free Software
41  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
42  *
43  ***********************************************************************/
44 
45 /***********************************************************************
46  * $Id: kcache.h,v 1.8 2007/01/25 22:42:09 leonb Exp $
47  **********************************************************************/
48 
49 #ifndef LARANK_H
50 #define LARANK_H
51 
#include <vector>
#include <set>
#include <map>
#define STDEXT_NAMESPACE __gnu_cxx
#define std_hash_map std::map
#define std_hash_set std::set

#include <shogun/lib/config.h>

#include <shogun/io/SGIO.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/multiclass/MulticlassSVM.h>
65 namespace shogun
66 {
67 #ifndef DOXYGEN_SHOULD_SKIP_THIS
68  struct larank_kcache_s;
69  typedef struct larank_kcache_s larank_kcache_t;
70  struct larank_kcache_s
71  {
72  CKernel* func;
73  larank_kcache_t *prevbuddy;
74  larank_kcache_t *nextbuddy;
75  int64_t maxsize;
76  int64_t cursize;
77  int32_t l;
78  int32_t *i2r;
79  int32_t *r2i;
80  int32_t maxrowlen;
81  /* Rows */
82  int32_t *rsize;
83  float32_t *rdiag;
84  float32_t **rdata;
85  int32_t *rnext;
86  int32_t *rprev;
87  int32_t *qnext;
88  int32_t *qprev;
89  };
90 
91  /*
92  ** OUTPUT: one per class of the raining set, keep tracks of support
93  * vectors and their beta coefficients
94  */
95  class LaRankOutput
96  {
97  public:
98  LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
99  {
100  }
101  virtual ~LaRankOutput ()
102  {
103  destroy();
104  }
105 
106  // Initializing an output class (basically creating a kernel cache for it)
107  void initialize (CKernel* kfunc, int64_t cache);
108 
109  // Destroying an output class (basically destroying the kernel cache)
110  void destroy ();
111 
112  // !Important! Computing the score of a given input vector for the actual output
113  float64_t computeScore (int32_t x_id);
114 
115  // !Important! Computing the gradient of a given input vector for the actual output
116  float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
117 
118  // Updating the solution in the actual output
119  void update (int32_t x_id, float64_t lambda, float64_t gp);
120 
121  // Linking the cache of this output to the cache of an other "buddy" output
122  // so that if a requested value is not found in this cache, you can
123  // ask your buddy if it has it.
124  void set_kernel_buddy (larank_kcache_t * bud);
125 
126  // Removing useless support vectors (for which beta=0)
127  int32_t cleanup ();
128 
129  // --- Below are information or "get" functions --- //
130 
131  //
132  inline larank_kcache_t *getKernel () const
133  {
134  return kernel;
135  }
136  //
137  inline int32_t get_l () const
138  {
139  return l;
140  }
141 
142  //
143  float64_t getW2 ();
144 
145  //
146  float64_t getKii (int32_t x_id);
147 
148  //
149  float64_t getBeta (int32_t x_id);
150 
151  //
152  inline float32_t* getBetas () const
153  {
154  return beta;
155  }
156 
157  //
158  float64_t getGradient (int32_t x_id);
159 
160  //
161  bool isSupportVector (int32_t x_id) const;
162 
163  //
164  int32_t getSV (float32_t* &sv) const;
165 
166  private:
167  // the solution of LaRank relative to the actual class is stored in
168  // this parameters
169  float32_t* beta; // Beta coefficiens
170  float32_t* g; // Strored gradient derivatives
171  larank_kcache_t *kernel; // Cache for kernel values
172  int32_t l; // Number of support vectors
173  };
174 
/*
 * LARANKPATTERN: an (example index, label) pair tracked as a support pattern.
 * A negative x_id marks a cleared pattern (see clear()).
 */
class LaRankPattern
{
public:
	/** Pattern for example @p x_index carrying label @p label. */
	LaRankPattern(int32_t x_index, int32_t label)
		: x_id(x_index), y(label) {}

	/** Default pattern: example 0, label 0. The label is initialized too,
	 *  so copying a default-constructed pattern never reads an
	 *  indeterminate value. */
	LaRankPattern()
		: x_id(0), y(0) {}

	/** True unless the pattern has been cleared. */
	bool exists() const
	{
		return x_id >= 0;
	}

	/** Marks the pattern as cleared (x_id = -1); the label is kept. */
	void clear()
	{
		x_id = -1;
	}

	int32_t x_id;  // example index, or -1 when cleared
	int32_t y;     // label of the example
};
199 
200  /*
201  ** LARANKPATTERNS: the collection of support patterns
202  */
203  class LaRankPatterns
204  {
205  public:
206  LaRankPatterns () {}
207  ~LaRankPatterns () {}
208 
209  void insert (const LaRankPattern & pattern)
210  {
211  if (!isPattern (pattern.x_id))
212  {
213  if (freeidx.size ())
214  {
215  std_hash_set < uint32_t >::iterator it = freeidx.begin ();
216  patterns[*it] = pattern;
217  x_id2rank[pattern.x_id] = *it;
218  freeidx.erase (it);
219  }
220  else
221  {
222  patterns.push_back (pattern);
223  x_id2rank[pattern.x_id] = patterns.size () - 1;
224  }
225  }
226  else
227  {
228  int32_t rank = getPatternRank (pattern.x_id);
229  patterns[rank] = pattern;
230  }
231  }
232 
233  void remove (uint32_t i)
234  {
235  x_id2rank[patterns[i].x_id] = 0;
236  patterns[i].clear ();
237  freeidx.insert (i);
238  }
239 
240  bool empty () const
241  {
242  return patterns.size () == freeidx.size ();
243  }
244 
245  uint32_t size () const
246  {
247  return patterns.size () - freeidx.size ();
248  }
249 
250  LaRankPattern & sample ()
251  {
252  ASSERT (!empty ())
253  while (true)
254  {
255  uint32_t r = CMath::random(uint32_t(0), uint32_t(patterns.size ()-1));
256  if (patterns[r].exists ())
257  return patterns[r];
258  }
259  return patterns[0];
260  }
261 
262  uint32_t getPatternRank (int32_t x_id)
263  {
264  return x_id2rank[x_id];
265  }
266 
267  bool isPattern (int32_t x_id)
268  {
269  return x_id2rank[x_id] != 0;
270  }
271 
272  LaRankPattern & getPattern (int32_t x_id)
273  {
274  uint32_t rank = x_id2rank[x_id];
275  return patterns[rank];
276  }
277 
278  uint32_t maxcount () const
279  {
280  return patterns.size ();
281  }
282 
283  LaRankPattern & operator [] (uint32_t i)
284  {
285  return patterns[i];
286  }
287 
288  const LaRankPattern & operator [] (uint32_t i) const
289  {
290  return patterns[i];
291  }
292 
293  private:
294  std_hash_set < uint32_t >freeidx;
295  std::vector < LaRankPattern > patterns;
296  std_hash_map < int32_t, uint32_t >x_id2rank;
297  };
298 
299 
300 #endif // DOXYGEN_SHOULD_SKIP_THIS
301 
302 
306  class CLaRank: public CMulticlassSVM
307  {
308  public:
309  CLaRank ();
310 
317  CLaRank(float64_t C, CKernel* k, CLabels* lab);
318 
319  virtual ~CLaRank ();
320 
321  // LEARNING FUNCTION: add new patterns and run optimization steps
322  // selected with adaptative schedule
327  virtual int32_t add (int32_t x_id, int32_t yi);
328 
329  // PREDICTION FUNCTION: main function in la_rank_classify
333  virtual int32_t predict (int32_t x_id);
334 
336  virtual void destroy ();
337 
338  // Compute Duality gap (costly but used in stopping criteria in batch mode)
340  virtual float64_t computeGap ();
341 
342  // Nuber of classes so far
344  virtual uint32_t getNumOutputs () const;
345 
346  // Number of Support Vectors
348  int32_t getNSV ();
349 
350  // Norm of the parameters vector
352  float64_t computeW2 ();
353 
354  // Compute Dual objective value
356  float64_t getDual ();
357 
363 
365  virtual const char* get_name() const { return "LaRank"; }
366 
370  void set_batch_mode(bool enable) { batch_mode=enable; };
372  bool get_batch_mode() { return batch_mode; };
376  void set_tau(float64_t t) { tau=t; };
380  float64_t get_tau() { return tau; };
381 
382  protected:
384  bool train_machine(CFeatures* data);
385 
386  private:
387  /*
388  ** MAIN DARK OPTIMIZATION PROCESSES
389  */
390 
391  // Hash Table used to store the different outputs
393  typedef std_hash_map < int32_t, LaRankOutput > outputhash_t; // class index -> LaRankOutput
394 
396  outputhash_t outputs;
397 
398  LaRankOutput *getOutput (int32_t index);
399 
400  //
401  LaRankPatterns patterns;
402 
403  // Parameters
404  int32_t nb_seen_examples;
405  int32_t nb_removed;
406 
407  // Numbers of each operation performed so far
408  int32_t n_pro;
409  int32_t n_rep;
410  int32_t n_opt;
411 
412  // Running estimates for each operations
413  float64_t w_pro;
414  float64_t w_rep;
415  float64_t w_opt;
416 
417  int32_t y0;
418  float64_t m_dual;
419 
420  struct outputgradient_t
421  {
422  outputgradient_t (int32_t result_output, float64_t result_gradient)
423  : output (result_output), gradient (result_gradient) {}
424  outputgradient_t ()
425  : output (0), gradient (0) {}
426 
427  int32_t output;
428  float64_t gradient;
429 
430  bool operator < (const outputgradient_t & og) const
431  {
432  return gradient > og.gradient;
433  }
434  };
435 
436  //3 types of operations in LaRank
437  enum process_type
438  {
439  processNew,
440  processOld,
441  processOptimize
442  };
443 
444  struct process_return_t
445  {
446  process_return_t (float64_t dual, int32_t yprediction)
447  : dual_increase (dual), ypred (yprediction) {}
448  process_return_t () {}
449  float64_t dual_increase;
450  int32_t ypred;
451  };
452 
453  // IMPORTANT Main SMO optimization step
454  process_return_t process (const LaRankPattern & pattern, process_type ptype);
455 
456  // ProcessOld
457  float64_t reprocess ();
458 
459  // Optimize
460  float64_t optimize ();
461 
462  // remove patterns and return the number of patterns that were removed
463  uint32_t cleanup ();
464 
465  protected:
466 
468  std_hash_set < int32_t >classes;
469 
471  inline uint32_t class_count () const
472  {
473  return classes.size ();
474  }
475 
478 
480  int32_t nb_train;
482  int64_t cache;
485 
487  int32_t step;
488  };
489 }
490 #endif // LARANK_H
virtual const char * get_name() const
Definition: LaRank.h:365
EMachineType
Definition: Machine.h:33
bool batch_mode
whether to use online learning or batch training
Definition: LaRank.h:484
virtual void destroy()
Definition: LaRank.cpp:795
int32_t getNSV()
Definition: LaRank.cpp:841
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
int64_t cache
cache
Definition: LaRank.h:482
bool train_machine(CFeatures *data)
Definition: LaRank.cpp:609
void set_batch_mode(bool enable)
Definition: LaRank.h:370
float64_t tau
tau
Definition: LaRank.h:477
void(* update)(float *foo, float bar)
Definition: JLCoverTree.h:528
virtual int32_t add(int32_t x_id, int32_t yi)
Definition: LaRank.cpp:702
uint32_t class_count() const
class count
Definition: LaRank.h:471
virtual EMachineType get_classifier_type()
Definition: LaRank.h:362
float64_t get_tau()
Definition: LaRank.h:380
virtual float64_t computeGap()
Definition: LaRank.cpp:804
static uint64_t random()
Definition: Math.h:1019
float64_t computeW2()
Definition: LaRank.cpp:854
int32_t nb_train
nb train
Definition: LaRank.h:480
#define ASSERT(x)
Definition: SGIO.h:201
float64_t getDual()
Definition: LaRank.cpp:870
class MultiClassSVM
Definition: MulticlassSVM.h:28
virtual ~CLaRank()
Definition: LaRank.cpp:604
double float64_t
Definition: common.h:50
bool get_batch_mode()
Definition: LaRank.h:372
virtual uint32_t getNumOutputs() const
Definition: LaRank.cpp:835
float float32_t
Definition: common.h:49
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual int32_t predict(int32_t x_id)
Definition: LaRank.cpp:779
the LaRank multiclass SVM machine
Definition: LaRank.h:306
The Kernel base class.
Definition: Kernel.h:158
void set_tau(float64_t t)
Definition: LaRank.h:376
std_hash_set< int32_t > classes
classes
Definition: LaRank.h:468
int32_t step
progress output
Definition: LaRank.h:487

SHOGUN 机器学习工具包 - 项目文档