Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/classifier/svm/MPDSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/lib/common.h>
00014 #include <shogun/mathematics/Math.h>
00015
00016 using namespace shogun;
00017
00018 CMPDSVM::CMPDSVM()
00019 : CSVM()
00020 {
00021 }
00022
00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab)
00025 {
00026 }
00027
00028 CMPDSVM::~CMPDSVM()
00029 {
00030 }
00031
00032 bool CMPDSVM::train_machine(CFeatures* data)
00033 {
00034 ASSERT(m_labels);
00035 ASSERT(m_labels->get_label_type() == LT_BINARY);
00036 ASSERT(kernel);
00037
00038 if (data)
00039 {
00040 if (m_labels->get_num_labels() != data->get_num_vectors())
00041 SG_ERROR("Number of training vectors does not match number of labels\n");
00042 kernel->init(data, data);
00043 }
00044 ASSERT(kernel->has_features());
00045
00046
00047 const float64_t alpha_eps=1e-12;
00048 const float64_t eps=get_epsilon();
00049 const int64_t maxiter = 1L<<30;
00050
00051
00052 const int32_t n=m_labels->get_num_labels();
00053 ASSERT(n>0);
00054
00055 const float64_t d = get_C1();
00056 const float64_t primaleps=eps;
00057 const float64_t dualeps=eps*n;
00058 int64_t niter=0;
00059
00060 kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00061 float64_t* alphas=SG_MALLOC(float64_t, n);
00062 float64_t* dalphas=SG_MALLOC(float64_t, n);
00063
00064 float64_t* hessres=SG_MALLOC(float64_t, n);
00065
00066 float64_t* F=SG_MALLOC(float64_t, n);
00067
00068
00069
00070
00071
00072 float64_t etas=0;
00073 float64_t detas=0;
00074 float64_t hessest=0;
00075 float64_t hstep;
00076
00077 const float64_t stopfac = 1;
00078
00079 bool primalcool;
00080 bool dualcool;
00081
00082
00083
00084
00085 for (int32_t i=0; i<n; i++)
00086 {
00087 alphas[i]=0;
00088 F[i]=((CBinaryLabels*) m_labels)->get_label(i);
00089
00090 hessres[i]=((CBinaryLabels*) m_labels)->get_label(i);
00091
00092
00093 dalphas[i]=-1;
00094 }
00095
00096
00097 while (niter++ < maxiter)
00098 {
00099 int32_t maxpidx=-1;
00100 float64_t maxpviol = -1;
00101
00102 float64_t maxdviol = CMath::abs(detas);
00103 bool free_alpha=false;
00104
00105
00106
00107
00108
00109 for (int32_t i=0; i<n; i++)
00110 {
00111 float64_t v=CMath::abs(dalphas[i]);
00112
00113 if (alphas[i] > 0 && alphas[i] < d)
00114 free_alpha=true;
00115
00116 if ( (dalphas[i]==0) ||
00117 (alphas[i]==0 && dalphas[i] >0) ||
00118 (alphas[i]==d && dalphas[i] <0)
00119 )
00120 v=0;
00121
00122 if (v > maxpviol)
00123 {
00124 maxpviol=v;
00125 maxpidx=i;
00126 }
00127 else if (v == maxpviol)
00128 {
00129 if (kernel_cache->is_cached(i))
00130 maxpidx=i;
00131 }
00132 }
00133
00134 if (maxpidx<0 || maxdviol<0)
00135 SG_ERROR( "no violation no convergence, should not happen!\n");
00136
00137
00138
00139
00140
00141
00142
00143 if (niter%10000 == 0)
00144 {
00145 float64_t obj=0;
00146
00147 for (int32_t i=0; i<n; i++)
00148 {
00149 obj-=alphas[i];
00150 for (int32_t j=0; j<n; j++)
00151 obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00152 }
00153
00154 SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00155 }
00156
00157
00158
00159
00160 primalcool = (maxpviol < primaleps*stopfac);
00161 dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00162
00163
00164 if (primalcool && dualcool)
00165 {
00166 if (!free_alpha)
00167 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00168 else
00169 SG_INFO( " done! #iter=%d\n", niter);
00170 break;
00171 }
00172
00173
00174 ASSERT(maxpidx>=0 && maxpidx<n);
00175
00176 hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00177
00178
00179
00180 hessest-=F[maxpidx]*hstep;
00181
00182
00183
00184
00185 float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00186
00187 if (tmpalpha > d-alpha_eps)
00188 tmpalpha = d;
00189
00190 if (tmpalpha < 0+alpha_eps)
00191 tmpalpha = 0;
00192
00193
00194 float64_t alphachange = tmpalpha - alphas[maxpidx];
00195 alphas[maxpidx] = tmpalpha;
00196
00197 KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00198 for (int32_t i=0; i<n; i++)
00199 {
00200 hessres[i]+=h[i]*hstep;
00201
00202
00203 dalphas[i] +=h[i]*alphachange;
00204 }
00205 unlock_kernel_row(maxpidx);
00206
00207 detas+=F[maxpidx]*alphachange;
00208
00209
00210
00211
00212 if (primalcool)
00213 {
00214
00215 float64_t etachange = detas/hessest;
00216
00217 etas+=etachange;
00218
00219
00220
00221
00222 for (int32_t i=0; i<n; i++)
00223 dalphas[i]+= F[i] * etachange;
00224
00225 }
00226 }
00227
00228 if (niter >= maxiter)
00229 SG_WARNING( "increase maxiter ... \n");
00230
00231
00232 int32_t nsv=0;
00233 for (int32_t i=0; i<n; i++)
00234 {
00235 if (alphas[i]>0)
00236 nsv++;
00237 }
00238
00239
00240 create_new_model(nsv);
00241
00242 set_bias(etas);
00243
00244 int32_t j=0;
00245 for (int32_t i=0; i<n; i++)
00246 {
00247 if (alphas[i]>0)
00248 {
00249
00250 set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i));
00251 set_support_vector(j, i);
00252 j++;
00253 }
00254 }
00255 compute_svm_dual_objective();
00256 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00257 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00258
00259 SG_FREE(alphas);
00260 SG_FREE(dalphas);
00261 SG_FREE(hessres);
00262 SG_FREE(F);
00263 delete kernel_cache;
00264
00265 return true;
00266 }