MPDSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/MPDSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/lib/common.h>
00014 #include <shogun/mathematics/Math.h>
00015 
00016 using namespace shogun;
00017 
00018 CMPDSVM::CMPDSVM()
00019 : CSVM()
00020 {
00021 }
00022 
00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab)
00025 {
00026 }
00027 
00028 CMPDSVM::~CMPDSVM()
00029 {
00030 }
00031 
00032 bool CMPDSVM::train_machine(CFeatures* data)
00033 {
00034     ASSERT(labels);
00035     ASSERT(kernel);
00036 
00037     if (data)
00038     {
00039         if (labels->get_num_labels() != data->get_num_vectors())
00040             SG_ERROR("Number of training vectors does not match number of labels\n");
00041         kernel->init(data, data);
00042     }
00043     ASSERT(kernel->has_features());
00044 
00045     //const float64_t nu=0.32;
00046     const float64_t alpha_eps=1e-12;
00047     const float64_t eps=get_epsilon();
00048     const int64_t maxiter = 1L<<30;
00049     //const bool nustop=false;
00050     //const int32_t k=2;
00051     const int32_t n=labels->get_num_labels();
00052     ASSERT(n>0);
00053     //const float64_t d = 1.0/n/nu; //NUSVC
00054     const float64_t d = get_C1(); //CSVC
00055     const float64_t primaleps=eps;
00056     const float64_t dualeps=eps*n; //heuristic
00057     int64_t niter=0;
00058 
00059     kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00060     float64_t* alphas=SG_MALLOC(float64_t, n);
00061     float64_t* dalphas=SG_MALLOC(float64_t, n);
00062     //float64_t* hessres=SG_MALLOC(float64_t, 2*n);
00063     float64_t* hessres=SG_MALLOC(float64_t, n);
00064     //float64_t* F=SG_MALLOC(float64_t, 2*n);
00065     float64_t* F=SG_MALLOC(float64_t, n);
00066 
00067     //float64_t hessest[2]={0,0};
00068     //float64_t hstep[2];
00069     //float64_t etas[2]={0,0};
00070     //float64_t detas[2]={0,1}; //NUSVC
00071     float64_t etas=0;
00072     float64_t detas=0;   //CSVC
00073     float64_t hessest=0;
00074     float64_t hstep;
00075 
00076     const float64_t stopfac = 1;
00077 
00078     bool primalcool;
00079     bool dualcool;
00080 
00081     //if (nustop)
00082     //etas[1] = 1;
00083 
00084     for (int32_t i=0; i<n; i++)
00085     {
00086         alphas[i]=0;
00087         F[i]=labels->get_label(i);
00088         //F[i+n]=-1;
00089         hessres[i]=labels->get_label(i);
00090         //hessres[i+n]=-1;
00091         //dalphas[i]=F[i+n]*etas[1]; //NUSVC
00092         dalphas[i]=-1; //CSVC
00093     }
00094 
00095     // go ...
00096     while (niter++ < maxiter)
00097     {
00098         int32_t maxpidx=-1;
00099         float64_t maxpviol = -1;
00100         //float64_t maxdviol = CMath::abs(detas[0]);
00101         float64_t maxdviol = CMath::abs(detas);
00102         bool free_alpha=false;
00103 
00104         //if (CMath::abs(detas[1])> maxdviol)
00105         //maxdviol=CMath::abs(detas[1]);
00106 
00107         // compute kkt violations with correct sign ...
00108         for (int32_t i=0; i<n; i++)
00109         {
00110             float64_t v=CMath::abs(dalphas[i]);
00111 
00112             if (alphas[i] > 0 && alphas[i] < d)
00113                 free_alpha=true;
00114 
00115             if ( (dalphas[i]==0) ||
00116                     (alphas[i]==0 && dalphas[i] >0) ||
00117                     (alphas[i]==d && dalphas[i] <0)
00118                )
00119                 v=0;
00120 
00121             if (v > maxpviol)
00122             {
00123                 maxpviol=v;
00124                 maxpidx=i;
00125             } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
00126             else if (v == maxpviol) 
00127             {
00128                 if (kernel_cache->is_cached(i))
00129                     maxpidx=i;
00130             }
00131         }
00132 
00133         if (maxpidx<0 || maxdviol<0)
00134             SG_ERROR( "no violation no convergence, should not happen!\n");
00135 
00136         // ... and evaluate stopping conditions
00137         //if (nustop)
00138         //stopfac = CMath::max(etas[1], 1e-10);    
00139         //else
00140         //stopfac = 1;
00141 
00142         if (niter%10000 == 0)
00143         {
00144             float64_t obj=0;
00145 
00146             for (int32_t i=0; i<n; i++)
00147             {
00148                 obj-=alphas[i];
00149                 for (int32_t j=0; j<n; j++)
00150                     obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00151             }
00152 
00153             SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00154         }
00155 
00156         //for (int32_t i=0; i<n; i++)
00157         //  SG_DEBUG( "alphas:%f dalphas:%f\n", alphas[i], dalphas[i]);
00158 
00159         primalcool = (maxpviol < primaleps*stopfac);
00160         dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00161 
00162         // done?
00163         if (primalcool && dualcool)
00164         {
00165             if (!free_alpha)
00166                 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00167             else
00168                 SG_INFO( " done! #iter=%d\n", niter);
00169             break;
00170         }
00171 
00172 
00173         ASSERT(maxpidx>=0 && maxpidx<n);
00174         // hessian updates
00175         hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00176         //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
00177         //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
00178 
00179         hessest-=F[maxpidx]*hstep;
00180         //hessest[0]-=F[maxpidx]*hstep[0];
00181         //hessest[1]-=F[maxpidx+n]*hstep[1];
00182 
00183         // do primal updates ..
00184         float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00185 
00186         if (tmpalpha > d-alpha_eps) 
00187             tmpalpha = d;
00188 
00189         if (tmpalpha < 0+alpha_eps)
00190             tmpalpha = 0;
00191 
00192         // update alphas & dalphas & detas ...
00193         float64_t alphachange = tmpalpha - alphas[maxpidx];
00194         alphas[maxpidx] = tmpalpha;
00195 
00196         KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00197         for (int32_t i=0; i<n; i++)
00198         {
00199             hessres[i]+=h[i]*hstep;
00200             //hessres[i]+=h[i]*hstep[0];
00201             //hessres[i+n]+=h[i]*hstep[1];
00202             dalphas[i] +=h[i]*alphachange;
00203         }
00204         unlock_kernel_row(maxpidx);
00205 
00206         detas+=F[maxpidx]*alphachange;
00207         //detas[0]+=F[maxpidx]*alphachange;
00208         //detas[1]+=F[maxpidx+n]*alphachange;
00209 
00210         // if at primal minimum, do eta update ...            
00211         if (primalcool)
00212         {
00213             //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
00214             float64_t etachange = detas/hessest;
00215 
00216             etas+=etachange;        
00217             //etas[0]+=etachange[0];        
00218             //etas[1]+=etachange[1];        
00219 
00220             // update dalphas
00221             for (int32_t i=0; i<n; i++)
00222                 dalphas[i]+= F[i] * etachange;
00223             //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
00224         }
00225     }
00226 
00227     if (niter >= maxiter)
00228         SG_WARNING( "increase maxiter ... \n");
00229 
00230 
00231     int32_t nsv=0;
00232     for (int32_t i=0; i<n; i++)
00233     {
00234         if (alphas[i]>0)
00235             nsv++;
00236     }
00237 
00238 
00239     create_new_model(nsv);
00240     //set_bias(etas[0]/etas[1]);
00241     set_bias(etas);
00242 
00243     int32_t j=0;
00244     for (int32_t i=0; i<n; i++)
00245     {
00246         if (alphas[i]>0)
00247         {
00248             //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
00249             set_alpha(j, alphas[i]*labels->get_label(i));
00250             set_support_vector(j, i);
00251             j++;
00252         }
00253     }
00254     compute_svm_dual_objective();
00255     SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00256     SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00257 
00258     SG_FREE(alphas);
00259     SG_FREE(dalphas);
00260     SG_FREE(hessres);
00261     SG_FREE(F);
00262     delete kernel_cache;
00263 
00264     return true;
00265 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation