MPDSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/MPDSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/lib/common.h>
00014 #include <shogun/mathematics/Math.h>
00015 
00016 using namespace shogun;
00017 
00018 CMPDSVM::CMPDSVM()
00019 : CSVM()
00020 {
00021 }
00022 
00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab)
00025 {
00026 }
00027 
00028 CMPDSVM::~CMPDSVM()
00029 {
00030 }
00031 
00032 bool CMPDSVM::train_machine(CFeatures* data)
00033 {
00034     ASSERT(m_labels);
00035     ASSERT(m_labels->get_label_type() == LT_BINARY);
00036     ASSERT(kernel);
00037 
00038     if (data)
00039     {
00040         if (m_labels->get_num_labels() != data->get_num_vectors())
00041             SG_ERROR("Number of training vectors does not match number of labels\n");
00042         kernel->init(data, data);
00043     }
00044     ASSERT(kernel->has_features());
00045 
00046     //const float64_t nu=0.32;
00047     const float64_t alpha_eps=1e-12;
00048     const float64_t eps=get_epsilon();
00049     const int64_t maxiter = 1L<<30;
00050     //const bool nustop=false;
00051     //const int32_t k=2;
00052     const int32_t n=m_labels->get_num_labels();
00053     ASSERT(n>0);
00054     //const float64_t d = 1.0/n/nu; //NUSVC
00055     const float64_t d = get_C1(); //CSVC
00056     const float64_t primaleps=eps;
00057     const float64_t dualeps=eps*n; //heuristic
00058     int64_t niter=0;
00059 
00060     kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00061     float64_t* alphas=SG_MALLOC(float64_t, n);
00062     float64_t* dalphas=SG_MALLOC(float64_t, n);
00063     //float64_t* hessres=SG_MALLOC(float64_t, 2*n);
00064     float64_t* hessres=SG_MALLOC(float64_t, n);
00065     //float64_t* F=SG_MALLOC(float64_t, 2*n);
00066     float64_t* F=SG_MALLOC(float64_t, n);
00067 
00068     //float64_t hessest[2]={0,0};
00069     //float64_t hstep[2];
00070     //float64_t etas[2]={0,0};
00071     //float64_t detas[2]={0,1}; //NUSVC
00072     float64_t etas=0;
00073     float64_t detas=0;   //CSVC
00074     float64_t hessest=0;
00075     float64_t hstep;
00076 
00077     const float64_t stopfac = 1;
00078 
00079     bool primalcool;
00080     bool dualcool;
00081 
00082     //if (nustop)
00083     //etas[1] = 1;
00084 
00085     for (int32_t i=0; i<n; i++)
00086     {
00087         alphas[i]=0;
00088         F[i]=((CBinaryLabels*) m_labels)->get_label(i);
00089         //F[i+n]=-1;
00090         hessres[i]=((CBinaryLabels*) m_labels)->get_label(i);
00091         //hessres[i+n]=-1;
00092         //dalphas[i]=F[i+n]*etas[1]; //NUSVC
00093         dalphas[i]=-1; //CSVC
00094     }
00095 
00096     // go ...
00097     while (niter++ < maxiter)
00098     {
00099         int32_t maxpidx=-1;
00100         float64_t maxpviol = -1;
00101         //float64_t maxdviol = CMath::abs(detas[0]);
00102         float64_t maxdviol = CMath::abs(detas);
00103         bool free_alpha=false;
00104 
00105         //if (CMath::abs(detas[1])> maxdviol)
00106         //maxdviol=CMath::abs(detas[1]);
00107 
00108         // compute kkt violations with correct sign ...
00109         for (int32_t i=0; i<n; i++)
00110         {
00111             float64_t v=CMath::abs(dalphas[i]);
00112 
00113             if (alphas[i] > 0 && alphas[i] < d)
00114                 free_alpha=true;
00115 
00116             if ( (dalphas[i]==0) ||
00117                     (alphas[i]==0 && dalphas[i] >0) ||
00118                     (alphas[i]==d && dalphas[i] <0)
00119                )
00120                 v=0;
00121 
00122             if (v > maxpviol)
00123             {
00124                 maxpviol=v;
00125                 maxpidx=i;
00126             } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
00127             else if (v == maxpviol)
00128             {
00129                 if (kernel_cache->is_cached(i))
00130                     maxpidx=i;
00131             }
00132         }
00133 
00134         if (maxpidx<0 || maxdviol<0)
00135             SG_ERROR( "no violation no convergence, should not happen!\n");
00136 
00137         // ... and evaluate stopping conditions
00138         //if (nustop)
00139         //stopfac = CMath::max(etas[1], 1e-10);
00140         //else
00141         //stopfac = 1;
00142 
00143         if (niter%10000 == 0)
00144         {
00145             float64_t obj=0;
00146 
00147             for (int32_t i=0; i<n; i++)
00148             {
00149                 obj-=alphas[i];
00150                 for (int32_t j=0; j<n; j++)
00151                     obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00152             }
00153 
00154             SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00155         }
00156 
00157         //for (int32_t i=0; i<n; i++)
00158         //  SG_DEBUG( "alphas:%f dalphas:%f\n", alphas[i], dalphas[i]);
00159 
00160         primalcool = (maxpviol < primaleps*stopfac);
00161         dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00162 
00163         // done?
00164         if (primalcool && dualcool)
00165         {
00166             if (!free_alpha)
00167                 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00168             else
00169                 SG_INFO( " done! #iter=%d\n", niter);
00170             break;
00171         }
00172 
00173 
00174         ASSERT(maxpidx>=0 && maxpidx<n);
00175         // hessian updates
00176         hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00177         //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
00178         //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
00179 
00180         hessest-=F[maxpidx]*hstep;
00181         //hessest[0]-=F[maxpidx]*hstep[0];
00182         //hessest[1]-=F[maxpidx+n]*hstep[1];
00183 
00184         // do primal updates ..
00185         float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00186 
00187         if (tmpalpha > d-alpha_eps)
00188             tmpalpha = d;
00189 
00190         if (tmpalpha < 0+alpha_eps)
00191             tmpalpha = 0;
00192 
00193         // update alphas & dalphas & detas ...
00194         float64_t alphachange = tmpalpha - alphas[maxpidx];
00195         alphas[maxpidx] = tmpalpha;
00196 
00197         KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00198         for (int32_t i=0; i<n; i++)
00199         {
00200             hessres[i]+=h[i]*hstep;
00201             //hessres[i]+=h[i]*hstep[0];
00202             //hessres[i+n]+=h[i]*hstep[1];
00203             dalphas[i] +=h[i]*alphachange;
00204         }
00205         unlock_kernel_row(maxpidx);
00206 
00207         detas+=F[maxpidx]*alphachange;
00208         //detas[0]+=F[maxpidx]*alphachange;
00209         //detas[1]+=F[maxpidx+n]*alphachange;
00210 
00211         // if at primal minimum, do eta update ...
00212         if (primalcool)
00213         {
00214             //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
00215             float64_t etachange = detas/hessest;
00216 
00217             etas+=etachange;
00218             //etas[0]+=etachange[0];
00219             //etas[1]+=etachange[1];
00220 
00221             // update dalphas
00222             for (int32_t i=0; i<n; i++)
00223                 dalphas[i]+= F[i] * etachange;
00224             //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
00225         }
00226     }
00227 
00228     if (niter >= maxiter)
00229         SG_WARNING( "increase maxiter ... \n");
00230 
00231 
00232     int32_t nsv=0;
00233     for (int32_t i=0; i<n; i++)
00234     {
00235         if (alphas[i]>0)
00236             nsv++;
00237     }
00238 
00239 
00240     create_new_model(nsv);
00241     //set_bias(etas[0]/etas[1]);
00242     set_bias(etas);
00243 
00244     int32_t j=0;
00245     for (int32_t i=0; i<n; i++)
00246     {
00247         if (alphas[i]>0)
00248         {
00249             //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
00250             set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i));
00251             set_support_vector(j, i);
00252             j++;
00253         }
00254     }
00255     compute_svm_dual_objective();
00256     SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00257     SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00258 
00259     SG_FREE(alphas);
00260     SG_FREE(dalphas);
00261     SG_FREE(hessres);
00262     SG_FREE(F);
00263     delete kernel_cache;
00264 
00265     return true;
00266 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation