Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/classifier/svm/MPDSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/lib/common.h>
00014 #include <shogun/mathematics/Math.h>
00015
00016 using namespace shogun;
00017
00018 CMPDSVM::CMPDSVM()
00019 : CSVM()
00020 {
00021 }
00022
00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab)
00025 {
00026 }
00027
00028 CMPDSVM::~CMPDSVM()
00029 {
00030 }
00031
00032 bool CMPDSVM::train_machine(CFeatures* data)
00033 {
00034 ASSERT(labels);
00035 ASSERT(kernel);
00036
00037 if (data)
00038 {
00039 if (labels->get_num_labels() != data->get_num_vectors())
00040 SG_ERROR("Number of training vectors does not match number of labels\n");
00041 kernel->init(data, data);
00042 }
00043 ASSERT(kernel->has_features());
00044
00045
00046 const float64_t alpha_eps=1e-12;
00047 const float64_t eps=get_epsilon();
00048 const int64_t maxiter = 1L<<30;
00049
00050
00051 const int32_t n=labels->get_num_labels();
00052 ASSERT(n>0);
00053
00054 const float64_t d = get_C1();
00055 const float64_t primaleps=eps;
00056 const float64_t dualeps=eps*n;
00057 int64_t niter=0;
00058
00059 kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00060 float64_t* alphas=SG_MALLOC(float64_t, n);
00061 float64_t* dalphas=SG_MALLOC(float64_t, n);
00062
00063 float64_t* hessres=SG_MALLOC(float64_t, n);
00064
00065 float64_t* F=SG_MALLOC(float64_t, n);
00066
00067
00068
00069
00070
00071 float64_t etas=0;
00072 float64_t detas=0;
00073 float64_t hessest=0;
00074 float64_t hstep;
00075
00076 const float64_t stopfac = 1;
00077
00078 bool primalcool;
00079 bool dualcool;
00080
00081
00082
00083
00084 for (int32_t i=0; i<n; i++)
00085 {
00086 alphas[i]=0;
00087 F[i]=labels->get_label(i);
00088
00089 hessres[i]=labels->get_label(i);
00090
00091
00092 dalphas[i]=-1;
00093 }
00094
00095
00096 while (niter++ < maxiter)
00097 {
00098 int32_t maxpidx=-1;
00099 float64_t maxpviol = -1;
00100
00101 float64_t maxdviol = CMath::abs(detas);
00102 bool free_alpha=false;
00103
00104
00105
00106
00107
00108 for (int32_t i=0; i<n; i++)
00109 {
00110 float64_t v=CMath::abs(dalphas[i]);
00111
00112 if (alphas[i] > 0 && alphas[i] < d)
00113 free_alpha=true;
00114
00115 if ( (dalphas[i]==0) ||
00116 (alphas[i]==0 && dalphas[i] >0) ||
00117 (alphas[i]==d && dalphas[i] <0)
00118 )
00119 v=0;
00120
00121 if (v > maxpviol)
00122 {
00123 maxpviol=v;
00124 maxpidx=i;
00125 }
00126 else if (v == maxpviol)
00127 {
00128 if (kernel_cache->is_cached(i))
00129 maxpidx=i;
00130 }
00131 }
00132
00133 if (maxpidx<0 || maxdviol<0)
00134 SG_ERROR( "no violation no convergence, should not happen!\n");
00135
00136
00137
00138
00139
00140
00141
00142 if (niter%10000 == 0)
00143 {
00144 float64_t obj=0;
00145
00146 for (int32_t i=0; i<n; i++)
00147 {
00148 obj-=alphas[i];
00149 for (int32_t j=0; j<n; j++)
00150 obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00151 }
00152
00153 SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00154 }
00155
00156
00157
00158
00159 primalcool = (maxpviol < primaleps*stopfac);
00160 dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00161
00162
00163 if (primalcool && dualcool)
00164 {
00165 if (!free_alpha)
00166 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00167 else
00168 SG_INFO( " done! #iter=%d\n", niter);
00169 break;
00170 }
00171
00172
00173 ASSERT(maxpidx>=0 && maxpidx<n);
00174
00175 hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00176
00177
00178
00179 hessest-=F[maxpidx]*hstep;
00180
00181
00182
00183
00184 float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00185
00186 if (tmpalpha > d-alpha_eps)
00187 tmpalpha = d;
00188
00189 if (tmpalpha < 0+alpha_eps)
00190 tmpalpha = 0;
00191
00192
00193 float64_t alphachange = tmpalpha - alphas[maxpidx];
00194 alphas[maxpidx] = tmpalpha;
00195
00196 KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00197 for (int32_t i=0; i<n; i++)
00198 {
00199 hessres[i]+=h[i]*hstep;
00200
00201
00202 dalphas[i] +=h[i]*alphachange;
00203 }
00204 unlock_kernel_row(maxpidx);
00205
00206 detas+=F[maxpidx]*alphachange;
00207
00208
00209
00210
00211 if (primalcool)
00212 {
00213
00214 float64_t etachange = detas/hessest;
00215
00216 etas+=etachange;
00217
00218
00219
00220
00221 for (int32_t i=0; i<n; i++)
00222 dalphas[i]+= F[i] * etachange;
00223
00224 }
00225 }
00226
00227 if (niter >= maxiter)
00228 SG_WARNING( "increase maxiter ... \n");
00229
00230
00231 int32_t nsv=0;
00232 for (int32_t i=0; i<n; i++)
00233 {
00234 if (alphas[i]>0)
00235 nsv++;
00236 }
00237
00238
00239 create_new_model(nsv);
00240
00241 set_bias(etas);
00242
00243 int32_t j=0;
00244 for (int32_t i=0; i<n; i++)
00245 {
00246 if (alphas[i]>0)
00247 {
00248
00249 set_alpha(j, alphas[i]*labels->get_label(i));
00250 set_support_vector(j, i);
00251 j++;
00252 }
00253 }
00254 compute_svm_dual_objective();
00255 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00256 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00257
00258 SG_FREE(alphas);
00259 SG_FREE(dalphas);
00260 SG_FREE(hessres);
00261 SG_FREE(F);
00262 delete kernel_cache;
00263
00264 return true;
00265 }