LaRank.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 // Main functions of the LaRank algorithm for solving Multiclass SVM
00003 // Copyright (C) 2008- Antoine Bordes
00004 // Shogun specific adjustments (w) 2009 Soeren Sonnenburg
00005 
00006 // This library is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 2.1 of the License, or (at your option) any later version.
00010 // 
00011 // This program is distributed in the hope that it will be useful,
00012 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014 // GNU General Public License for more details.
00015 // 
00016 // You should have received a copy of the GNU General Public License
00017 // along with this program; if not, write to the Free Software
00018 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
00019 //
00020 /***********************************************************************
00021  * 
00022  *  LUSH Lisp Universal Shell
00023  *    Copyright (C) 2002 Leon Bottou, Yann Le Cun, AT&T Corp, NECI.
00024  *  Includes parts of TL3:
00025  *    Copyright (C) 1987-1999 Leon Bottou and Neuristique.
00026  *  Includes selected parts of SN3.2:
00027  *    Copyright (C) 1991-2001 AT&T Corp.
00028  * 
00029  *  This program is free software; you can redistribute it and/or modify
00030  *  it under the terms of the GNU General Public License as published by
00031  *  the Free Software Foundation; either version 2 of the License, or
00032  *  (at your option) any later version.
00033  * 
00034  *  This program is distributed in the hope that it will be useful,
00035  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00036  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00037  *  GNU General Public License for more details.
00038  * 
00039  *  You should have received a copy of the GNU General Public License
00040  *  along with this program; if not, write to the Free Software
00041  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
00042  * 
00043  ***********************************************************************/
00044 
00045 /***********************************************************************
00046  * $Id: kcache.h,v 1.8 2007/01/25 22:42:09 leonb Exp $
00047  **********************************************************************/
00048 
00049 #ifndef LARANK_H
00050 #define LARANK_H
00051 
00052 #include <ctime>
00053 #include <vector>
00054 #include <algorithm>
00055 #include <sys/time.h>
00056 #include <ext/hash_map>
00057 #include <ext/hash_set>
00058 
00059 #define STDEXT_NAMESPACE __gnu_cxx
00060 #define std_hash_map STDEXT_NAMESPACE::hash_map
00061 #define std_hash_set STDEXT_NAMESPACE::hash_set
00062 
00063 #include "lib/io.h"
00064 #include "kernel/Kernel.h"
00065 #include "classifier/svm/MultiClassSVM.h"
00066 
00067 namespace shogun
00068 {
00069 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00070     struct larank_kcache_s;
00071     typedef struct larank_kcache_s larank_kcache_t;
00072     struct larank_kcache_s
00073     {
00074         CKernel* func;
00075         larank_kcache_t *prevbuddy;
00076         larank_kcache_t *nextbuddy;
00077         int64_t maxsize;
00078         int64_t cursize;
00079         int32_t l;
00080         int32_t *i2r;
00081         int32_t *r2i;
00082         int32_t maxrowlen;
00083         /* Rows */
00084         int32_t *rsize;
00085         float32_t *rdiag;
00086         float32_t **rdata;
00087         int32_t *rnext;
00088         int32_t *rprev;
00089         int32_t *qnext;
00090         int32_t *qprev;
00091     };
00092 
00093     /*
00094      ** OUTPUT: one per class of the raining set, keep tracks of support
00095      * vectors and their beta coefficients
00096      */
00097     class LaRankOutput
00098     {
00099         public:
00100             LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
00101         {
00102         }
00103             virtual ~LaRankOutput ()
00104             {
00105                 destroy();
00106             }
00107 
00108             // Initializing an output class (basically creating a kernel cache for it)
00109             void initialize (CKernel* kfunc, int64_t cache);
00110 
00111             // Destroying an output class (basically destroying the kernel cache)
00112             void destroy ();
00113 
00114             // !Important! Computing the score of a given input vector for the actual output
00115             float64_t computeScore (int32_t x_id);
00116 
00117             // !Important! Computing the gradient of a given input vector for the actual output           
00118             float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
00119 
00120             // Updating the solution in the actual output
00121             void update (int32_t x_id, float64_t lambda, float64_t gp);
00122 
00123             // Linking the cache of this output to the cache of an other "buddy" output
00124             // so that if a requested value is not found in this cache, you can
00125             // ask your buddy if it has it.                              
00126             void set_kernel_buddy (larank_kcache_t * bud);
00127 
00128             // Removing useless support vectors (for which beta=0)                
00129             int32_t cleanup ();
00130 
00131             // --- Below are information or "get" functions --- //
00132 
00133             //                            
00134             inline larank_kcache_t *getKernel () const
00135             {
00136                 return kernel;
00137             }
00138             //                            
00139             inline int32_t get_l () const
00140             {
00141                 return l;
00142             }
00143 
00144             //
00145             float64_t getW2 ();
00146 
00147             //
00148             float64_t getKii (int32_t x_id);
00149 
00150             //
00151             float64_t getBeta (int32_t x_id);
00152 
00153             //
00154             inline float32_t* getBetas () const
00155             {
00156                 return beta;
00157             }
00158 
00159             //
00160             float64_t getGradient (int32_t x_id);
00161 
00162             //
00163             bool isSupportVector (int32_t x_id) const;
00164 
00165             //
00166             int32_t getSV (float32_t* &sv) const;
00167 
00168         private:
00169             // the solution of LaRank relative to the actual class is stored in
00170             // this parameters
00171             float32_t* beta;        // Beta coefficiens
00172             float32_t* g;       // Strored gradient derivatives
00173             larank_kcache_t *kernel;    // Cache for kernel values
00174             int32_t l;          // Number of support vectors 
00175     };
00176 
00177     /*
00178      ** LARANKPATTERN: to keep track of the support patterns
00179      */
00180     class LaRankPattern
00181     {
00182         public:
00183             LaRankPattern (int32_t x_index, int32_t label) 
00184                 : x_id (x_index), y (label) {}
00185             LaRankPattern () 
00186                 : x_id (0) {}
00187 
00188             bool exists () const
00189             {
00190                 return x_id >= 0;
00191             }
00192 
00193             void clear ()
00194             {
00195                 x_id = -1;
00196             }
00197 
00198             int32_t x_id;
00199             int32_t y;
00200     };
00201 
00202     /*
00203      **  LARANKPATTERNS: the collection of support patterns
00204      */
00205     class LaRankPatterns
00206     {
00207         public:
00208             LaRankPatterns () {}
00209             ~LaRankPatterns () {}
00210 
00211             void insert (const LaRankPattern & pattern)
00212             {
00213                 if (!isPattern (pattern.x_id))
00214                 {
00215                     if (freeidx.size ())
00216                     {
00217                         std_hash_set < uint32_t >::iterator it = freeidx.begin ();
00218                         patterns[*it] = pattern;
00219                         x_id2rank[pattern.x_id] = *it;
00220                         freeidx.erase (it);
00221                     }
00222                     else
00223                     {
00224                         patterns.push_back (pattern);
00225                         x_id2rank[pattern.x_id] = patterns.size () - 1;
00226                     }
00227                 }
00228                 else
00229                 {
00230                     int32_t rank = getPatternRank (pattern.x_id);
00231                     patterns[rank] = pattern;
00232                 }
00233             }
00234 
00235             void remove (uint32_t i)
00236             {
00237                 x_id2rank[patterns[i].x_id] = 0;
00238                 patterns[i].clear ();
00239                 freeidx.insert (i);
00240             }
00241 
00242             bool empty () const
00243             {
00244                 return patterns.size () == freeidx.size ();
00245             }
00246 
00247             uint32_t size () const
00248             {
00249                 return patterns.size () - freeidx.size ();
00250             }
00251 
00252             LaRankPattern & sample ()
00253             {
00254                 ASSERT (!empty ());
00255                 while (true)
00256                 {
00257                     uint32_t r = CMath::random(0, patterns.size ()-1);
00258                     if (patterns[r].exists ())
00259                         return patterns[r];
00260                 }
00261                 return patterns[0];
00262             }
00263 
00264             uint32_t getPatternRank (int32_t x_id)
00265             {
00266                 return x_id2rank[x_id];
00267             }
00268 
00269             bool isPattern (int32_t x_id)
00270             {
00271                 return x_id2rank[x_id] != 0;
00272             }
00273 
00274             LaRankPattern & getPattern (int32_t x_id)
00275             {
00276                 uint32_t rank = x_id2rank[x_id];
00277                 return patterns[rank];
00278             }
00279 
00280             uint32_t maxcount () const
00281             {
00282                 return patterns.size ();
00283             }
00284 
00285             LaRankPattern & operator [] (uint32_t i)
00286             {
00287                 return patterns[i];
00288             }
00289 
00290             const LaRankPattern & operator [] (uint32_t i) const
00291             {
00292                 return patterns[i];
00293             }
00294 
00295         private:
00296             std_hash_set < uint32_t >freeidx;
00297             std::vector < LaRankPattern > patterns;
00298             std_hash_map < int32_t, uint32_t >x_id2rank;
00299     };
00300 
00301 
00302 #endif // DOXYGEN_SHOULD_SKIP_THIS
00303 
00304 
00305     /*
00306      ** MACHINE: the main thing, which is trained.
00307      */
00308     class CLaRank:  public CMultiClassSVM
00309     {
00310         public:
00311             CLaRank ();
00312 
00319             CLaRank(float64_t C, CKernel* k, CLabels* lab);
00320 
00321             virtual ~CLaRank ();
00322 
00323             bool train(CFeatures* data);
00324 
00325 
00326             // LEARNING FUNCTION: add new patterns and run optimization steps
00327             // selected with adaptative schedule
00328             virtual int32_t add (int32_t x_id, int32_t yi);
00329 
00330             // PREDICTION FUNCTION: main function in la_rank_classify
00331             virtual int32_t predict (int32_t x_id);
00332 
00333             virtual void destroy ();
00334 
00335             // Compute Duality gap (costly but used in stopping criteria in batch mode)                     
00336             virtual float64_t computeGap ();
00337 
00338             // Nuber of classes so far
00339             virtual uint32_t getNumOutputs () const;
00340 
00341             // Number of Support Vectors
00342             int32_t getNSV ();
00343 
00344             // Norm of the parameters vector
00345             float64_t computeW2 ();
00346 
00347             // Compute Dual objective value
00348             float64_t getDual ();
00349 
00354             virtual inline EClassifierType get_classifier_type() { return CT_LARANK; }
00355 
00357             inline virtual const char* get_name() const { return "LaRank"; }
00358 
00359             void set_batch_mode(bool enable) { batch_mode=enable; };
00360             bool get_batch_mode() { return batch_mode; };
00361             void set_tau(float64_t t) { tau=t; };
00362             float64_t get_tau() { return tau; };
00363 
00364 
00365         private:
00366             /*
00367              ** MAIN DARK OPTIMIZATION PROCESSES
00368              */
00369 
00370             // Hash Table used to store the different outputs
00371             typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;    // class index -> LaRankOutput
00372 
00373 
00374             outputhash_t outputs;
00375             LaRankOutput *getOutput (int32_t index);
00376 
00377             // 
00378             LaRankPatterns patterns;
00379 
00380             // Parameters
00381             int32_t nb_seen_examples;
00382             int32_t nb_removed;
00383 
00384             // Numbers of each operation performed so far
00385             int32_t n_pro;
00386             int32_t n_rep;
00387             int32_t n_opt;
00388 
00389             // Running estimates for each operations 
00390             float64_t w_pro;
00391             float64_t w_rep;
00392             float64_t w_opt;
00393 
00394             int32_t y0;
00395             float64_t dual;
00396 
00397             struct outputgradient_t
00398             {
00399                 outputgradient_t (int32_t result_output, float64_t result_gradient)
00400                     : output (result_output), gradient (result_gradient) {}
00401                 outputgradient_t ()
00402                     : output (0), gradient (0) {}
00403 
00404                 int32_t output;
00405                 float64_t gradient;
00406 
00407                 bool operator < (const outputgradient_t & og) const
00408                 {
00409                     return gradient > og.gradient;
00410                 }
00411             };
00412 
00413             //3 types of operations in LaRank               
00414             enum process_type
00415             {
00416                 processNew,
00417                 processOld,
00418                 processOptimize
00419             };
00420 
00421             struct process_return_t
00422             {
00423                 process_return_t (float64_t dual, int32_t yprediction) 
00424                     : dual_increase (dual), ypred (yprediction) {}
00425                 process_return_t () {}
00426                 float64_t dual_increase;
00427                 int32_t ypred;
00428             };
00429 
00430             // IMPORTANT Main SMO optimization step
00431             process_return_t process (const LaRankPattern & pattern, process_type ptype);
00432 
00433             // ProcessOld
00434             float64_t reprocess ();
00435 
00436             // Optimize
00437             float64_t optimize ();
00438 
00439             // remove patterns and return the number of patterns that were removed
00440             uint32_t cleanup ();
00441 
00442         protected:
00443 
00444             std_hash_set < int32_t >classes;
00445 
00446             inline uint32_t class_count () const
00447             {
00448                 return classes.size ();
00449             }
00450 
00451             float64_t tau;
00452             int32_t nb_train;
00453             int64_t cache;
00454             // whether to use online learning or batch training
00455             bool batch_mode;
00456 
00457             //progess output
00458             int32_t step;
00459     };
00460 }
00461 #endif // LARANK_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation