00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifdef USE_SVMLIGHT
00012 #include <shogun/classifier/svm/SVMLightOneClass.h>
00013 #endif //USE_SVMLIGHT
00014
00015 #include <shogun/kernel/Kernel.h>
00016 #include <shogun/multiclass/ScatterSVM.h>
00017 #include <shogun/kernel/normalizer/ScatterKernelNormalizer.h>
00018 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00019 #include <shogun/io/SGIO.h>
00020
00021 using namespace shogun;
00022
/** Default constructor.
 *
 * Uses a one-vs-rest multiclass strategy and the NO_BIAS_LIBSVM training
 * variant; all model state (libsvm model, per-class norm caches) starts NULL.
 */
CScatterSVM::CScatterSVM()
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM),
	model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
	// Default construction is flagged as unstable/experimental.
	SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n");
}
00029
/** Constructor selecting the training variant.
 *
 * @param type scatter type (NO_BIAS_LIBSVM, NO_BIAS_SVMLIGHT, TEST_RULE1/2)
 *             used later by train_machine() to dispatch the solver
 */
CScatterSVM::CScatterSVM(SCATTER_TYPE type)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL),
	norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}
00035
/** Constructor taking the usual SVM ingredients.
 *
 * Defaults to the NO_BIAS_LIBSVM scatter type.
 *
 * @param C regularization constant
 * @param k kernel to train with
 * @param lab training labels (multiclass)
 */
CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
	norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}
00041
/** Destructor: releases the per-class norm caches.
 * Both pointers may still be NULL if train() was never run; SG_FREE
 * is expected to tolerate that (free(NULL) semantics).
 */
CScatterSVM::~CScatterSVM()
{
	SG_FREE(norm_wc);
	SG_FREE(norm_wcw);
}
00047
00048 bool CScatterSVM::train_machine(CFeatures* data)
00049 {
00050 ASSERT(m_labels && m_labels->get_num_labels());
00051 ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00052
00053 m_num_classes = m_multiclass_strategy->get_num_classes();
00054 int32_t num_vectors = m_labels->get_num_labels();
00055
00056 if (data)
00057 {
00058 if (m_labels->get_num_labels() != data->get_num_vectors())
00059 SG_ERROR("Number of training vectors does not match number of labels\n");
00060 m_kernel->init(data, data);
00061 }
00062
00063 int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
00064 SGVector<int32_t>::fill_vector(numc, m_num_classes, 0);
00065
00066 for (int32_t i=0; i<num_vectors; i++)
00067 numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++;
00068
00069 int32_t Nc=0;
00070 int32_t Nmin=num_vectors;
00071 for (int32_t i=0; i<m_num_classes; i++)
00072 {
00073 if (numc[i]>0)
00074 {
00075 Nc++;
00076 Nmin=CMath::min(Nmin, numc[i]);
00077 }
00078
00079 }
00080 SG_FREE(numc);
00081 m_num_classes=m_num_classes;
00082
00083 bool result=false;
00084
00085 if (scatter_type==NO_BIAS_LIBSVM)
00086 {
00087 result=train_no_bias_libsvm();
00088 }
00089 #ifdef USE_SVMLIGHT
00090 else if (scatter_type==NO_BIAS_SVMLIGHT)
00091 {
00092 result=train_no_bias_svmlight();
00093 }
00094 #endif //USE_SVMLIGHT
00095 else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2)
00096 {
00097 float64_t nu_min=((float64_t) Nc)/num_vectors;
00098 float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;
00099
00100 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);
00101
00102 if (get_nu()<nu_min || get_nu()>nu_max)
00103 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);
00104
00105 result=train_testrule12();
00106 }
00107 else
00108 SG_ERROR("Unknown Scatter type\n");
00109
00110 return result;
00111 }
00112
00113 bool CScatterSVM::train_no_bias_libsvm()
00114 {
00115 struct svm_node* x_space;
00116
00117 problem.l=m_labels->get_num_labels();
00118 SG_INFO( "%d trainlabels\n", problem.l);
00119
00120 problem.y=SG_MALLOC(float64_t, problem.l);
00121 problem.x=SG_MALLOC(struct svm_node*, problem.l);
00122 x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00123
00124 for (int32_t i=0; i<problem.l; i++)
00125 {
00126 problem.y[i]=+1;
00127 problem.x[i]=&x_space[2*i];
00128 x_space[2*i].index=i;
00129 x_space[2*i+1].index=-1;
00130 }
00131
00132 int32_t weights_label[2]={-1,+1};
00133 float64_t weights[2]={1.0,get_C()/get_C()};
00134
00135 ASSERT(m_kernel && m_kernel->has_features());
00136 ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
00137
00138 param.svm_type=C_SVC;
00139 param.kernel_type = LINEAR;
00140 param.degree = 3;
00141 param.gamma = 0;
00142 param.coef0 = 0;
00143 param.nu = get_nu();
00144 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
00145 m_kernel->set_normalizer(new CScatterKernelNormalizer(
00146 m_num_classes-1, -1, m_labels, prev_normalizer));
00147 param.kernel=m_kernel;
00148 param.cache_size = m_kernel->get_cache_size();
00149 param.C = 0;
00150 param.eps = get_epsilon();
00151 param.p = 0.1;
00152 param.shrinking = 0;
00153 param.nr_weight = 2;
00154 param.weight_label = weights_label;
00155 param.weight = weights;
00156 param.nr_class=m_num_classes;
00157 param.use_bias = svm_proto()->get_bias_enabled();
00158
00159 const char* error_msg = svm_check_parameter(&problem,¶m);
00160
00161 if(error_msg)
00162 SG_ERROR("Error: %s\n",error_msg);
00163
00164 model = svm_train(&problem, ¶m);
00165 m_kernel->set_normalizer(prev_normalizer);
00166 SG_UNREF(prev_normalizer);
00167
00168 if (model)
00169 {
00170 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00171
00172 ASSERT(model->nr_class==m_num_classes);
00173 create_multiclass_svm(m_num_classes);
00174
00175 rho=model->rho[0];
00176
00177 SG_FREE(norm_wcw);
00178 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
00179
00180 for (int32_t i=0; i<m_num_classes; i++)
00181 {
00182 int32_t num_sv=model->nSV[i];
00183
00184 CSVM* svm=new CSVM(num_sv);
00185 svm->set_bias(model->rho[i+1]);
00186 norm_wcw[i]=model->normwcw[i];
00187
00188
00189 for (int32_t j=0; j<num_sv; j++)
00190 {
00191 svm->set_alpha(j, model->sv_coef[i][j]);
00192 svm->set_support_vector(j, model->SV[i][j].index);
00193 }
00194
00195 set_svm(i, svm);
00196 }
00197
00198 SG_FREE(problem.x);
00199 SG_FREE(problem.y);
00200 SG_FREE(x_space);
00201 for (int32_t i=0; i<m_num_classes; i++)
00202 {
00203 SG_FREE(model->SV[i]);
00204 model->SV[i]=NULL;
00205 }
00206 svm_destroy_model(model);
00207
00208 if (scatter_type==TEST_RULE2)
00209 compute_norm_wc();
00210
00211 model=NULL;
00212 return true;
00213 }
00214 else
00215 return false;
00216 }
00217
00218 #ifdef USE_SVMLIGHT
00219 bool CScatterSVM::train_no_bias_svmlight()
00220 {
00221 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
00222 CScatterKernelNormalizer* n=new CScatterKernelNormalizer(
00223 m_num_classes-1, -1, m_labels, prev_normalizer);
00224 m_kernel->set_normalizer(n);
00225 m_kernel->init_normalizer();
00226
00227 CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel);
00228 light->set_linadd_enabled(false);
00229 light->train();
00230
00231 SG_FREE(norm_wcw);
00232 norm_wcw = SG_MALLOC(float64_t, m_num_classes);
00233
00234 int32_t num_sv=light->get_num_support_vectors();
00235 svm_proto()->create_new_model(num_sv);
00236
00237 for (int32_t i=0; i<num_sv; i++)
00238 {
00239 svm_proto()->set_alpha(i, light->get_alpha(i));
00240 svm_proto()->set_support_vector(i, light->get_support_vector(i));
00241 }
00242
00243 m_kernel->set_normalizer(prev_normalizer);
00244 return true;
00245 }
00246 #endif //USE_SVMLIGHT
00247
00248 bool CScatterSVM::train_testrule12()
00249 {
00250 struct svm_node* x_space;
00251 problem.l=m_labels->get_num_labels();
00252 SG_INFO( "%d trainlabels\n", problem.l);
00253
00254 problem.y=SG_MALLOC(float64_t, problem.l);
00255 problem.x=SG_MALLOC(struct svm_node*, problem.l);
00256 x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00257
00258 for (int32_t i=0; i<problem.l; i++)
00259 {
00260 problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
00261 problem.x[i]=&x_space[2*i];
00262 x_space[2*i].index=i;
00263 x_space[2*i+1].index=-1;
00264 }
00265
00266 int32_t weights_label[2]={-1,+1};
00267 float64_t weights[2]={1.0,get_C()/get_C()};
00268
00269 ASSERT(m_kernel && m_kernel->has_features());
00270 ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
00271
00272 param.svm_type=NU_MULTICLASS_SVC;
00273 param.kernel_type = LINEAR;
00274 param.degree = 3;
00275 param.gamma = 0;
00276 param.coef0 = 0;
00277 param.nu = get_nu();
00278 param.kernel=m_kernel;
00279 param.cache_size = m_kernel->get_cache_size();
00280 param.C = 0;
00281 param.eps = get_epsilon();
00282 param.p = 0.1;
00283 param.shrinking = 0;
00284 param.nr_weight = 2;
00285 param.weight_label = weights_label;
00286 param.weight = weights;
00287 param.nr_class=m_num_classes;
00288 param.use_bias = svm_proto()->get_bias_enabled();
00289
00290 const char* error_msg = svm_check_parameter(&problem,¶m);
00291
00292 if(error_msg)
00293 SG_ERROR("Error: %s\n",error_msg);
00294
00295 model = svm_train(&problem, ¶m);
00296
00297 if (model)
00298 {
00299 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00300
00301 ASSERT(model->nr_class==m_num_classes);
00302 create_multiclass_svm(m_num_classes);
00303
00304 rho=model->rho[0];
00305
00306 SG_FREE(norm_wcw);
00307 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
00308
00309 for (int32_t i=0; i<m_num_classes; i++)
00310 {
00311 int32_t num_sv=model->nSV[i];
00312
00313 CSVM* svm=new CSVM(num_sv);
00314 svm->set_bias(model->rho[i+1]);
00315 norm_wcw[i]=model->normwcw[i];
00316
00317
00318 for (int32_t j=0; j<num_sv; j++)
00319 {
00320 svm->set_alpha(j, model->sv_coef[i][j]);
00321 svm->set_support_vector(j, model->SV[i][j].index);
00322 }
00323
00324 set_svm(i, svm);
00325 }
00326
00327 SG_FREE(problem.x);
00328 SG_FREE(problem.y);
00329 SG_FREE(x_space);
00330 for (int32_t i=0; i<m_num_classes; i++)
00331 {
00332 SG_FREE(model->SV[i]);
00333 model->SV[i]=NULL;
00334 }
00335 svm_destroy_model(model);
00336
00337 if (scatter_type==TEST_RULE2)
00338 compute_norm_wc();
00339
00340 model=NULL;
00341 return true;
00342 }
00343 else
00344 return false;
00345 }
00346
00347 void CScatterSVM::compute_norm_wc()
00348 {
00349 SG_FREE(norm_wc);
00350 norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements());
00351 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00352 norm_wc[i]=0;
00353
00354
00355 for (int c=0; c<m_machines->get_num_elements(); c++)
00356 {
00357 CSVM* svm=get_svm(c);
00358 int32_t num_sv = svm->get_num_support_vectors();
00359
00360 for (int32_t i=0; i<num_sv; i++)
00361 {
00362 int32_t ii=svm->get_support_vector(i);
00363 for (int32_t j=0; j<num_sv; j++)
00364 {
00365 int32_t jj=svm->get_support_vector(j);
00366 norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j);
00367 }
00368 }
00369 }
00370
00371 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00372 norm_wc[i]=CMath::sqrt(norm_wc[i]);
00373
00374 SGVector<float64_t>::display_vector(norm_wc, m_machines->get_num_elements(), "norm_wc");
00375 }
00376
/** Classify all vectors on the kernel's rhs using a one-vs-rest scheme.
 *
 * Three strategies, selected by scatter_type:
 *  - TEST_RULE1: per-vector apply() (scatter decision rule);
 *  - NO_BIAS_SVMLIGHT: score every class from the single prototype SVM's
 *    alphas, weighting each SV's kernel value by (m_num_classes-1) for its
 *    own class and -1 for the others, then take the argmax;
 *  - otherwise: run each per-class SVM over all vectors, normalize the
 *    outputs by norm_wc[c], and take the argmax.
 *
 * @return new SG_REF'ed CMulticlassLabels with predicted class indices,
 *         or NULL if no kernel / no features are set
 */
CLabels* CScatterSVM::classify_one_vs_rest()
{
	CMulticlassLabels* output=NULL;
	if (!m_kernel)
	{
		SG_ERROR( "SVM can not proceed without kernel!\n");
		return NULL;
	}

	// Both sides of the kernel must have features attached.
	if (!( m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs()))
		return NULL;

	int32_t num_vectors=m_kernel->get_num_vec_rhs();

	output=new CMulticlassLabels(num_vectors);
	SG_REF(output);

	if (scatter_type == TEST_RULE1)
	{
		ASSERT(m_machines->get_num_elements()>0);
		for (int32_t i=0; i<num_vectors; i++)
			output->set_label(i, apply(i));
	}
#ifdef USE_SVMLIGHT
	else if (scatter_type == NO_BIAS_SVMLIGHT)
	{
		// outputs is a num_vectors x m_num_classes score matrix
		// (class-major within each vector), initialized to zero.
		float64_t* outputs=SG_MALLOC(float64_t, num_vectors*m_num_classes);
		SGVector<float64_t>::fill_vector(outputs,num_vectors*m_num_classes,0.0);

		for (int32_t i=0; i<num_vectors; i++)
		{
			for (int32_t j=0; j<svm_proto()->get_num_support_vectors(); j++)
			{
				// Contribution of SV j to vector i, scaled per class:
				// +(m_num_classes-1) for the SV's own class, -1 otherwise.
				float64_t score=m_kernel->kernel(svm_proto()->get_support_vector(j), i)*svm_proto()->get_alpha(j);
				int32_t label=((CMulticlassLabels*) m_labels)->get_int_label(svm_proto()->get_support_vector(j));
				for (int32_t c=0; c<m_num_classes; c++)
				{
					float64_t s= (label==c) ? (m_num_classes-1) : (-1);
					outputs[c+i*m_num_classes]+=s*score;
				}
			}
		}

		// Argmax over classes for each vector.
		for (int32_t i=0; i<num_vectors; i++)
		{
			int32_t winner=0;
			float64_t max_out=outputs[i*m_num_classes+0];

			for (int32_t j=1; j<m_num_classes; j++)
			{
				float64_t out=outputs[i*m_num_classes+j];

				if (out>max_out)
				{
					winner=j;
					max_out=out;
				}
			}

			output->set_label(i, winner);
		}

		SG_FREE(outputs);
	}
#endif //USE_SVMLIGHT
	else
	{
		ASSERT(m_machines->get_num_elements()>0);
		ASSERT(num_vectors==output->get_num_labels());
		// One full apply() per machine; outputs[i] holds machine i's labels.
		CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements());

		for (int32_t i=0; i<m_machines->get_num_elements(); i++)
		{

			CSVM *svm = get_svm(i);
			ASSERT(svm);
			svm->set_kernel(m_kernel);
			svm->set_labels(m_labels);
			outputs[i]=svm->apply();
			SG_UNREF(svm);
		}

		// Argmax over machines, with each score normalized by ||w_c||.
		for (int32_t i=0; i<num_vectors; i++)
		{
			int32_t winner=0;
			float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0];

			for (int32_t j=1; j<m_machines->get_num_elements(); j++)
			{
				float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j];

				if (out>max_out)
				{
					winner=j;
					max_out=out;
				}
			}

			output->set_label(i, winner);
		}

		for (int32_t i=0; i<m_machines->get_num_elements(); i++)
			SG_UNREF(outputs[i]);

		SG_FREE(outputs);
	}

	return output;
}
00486
00487 float64_t CScatterSVM::apply(int32_t num)
00488 {
00489 ASSERT(m_machines->get_num_elements()>0);
00490 float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements());
00491 int32_t winner=0;
00492
00493 if (scatter_type == TEST_RULE1)
00494 {
00495 for (int32_t c=0; c<m_machines->get_num_elements(); c++)
00496 outputs[c]=get_svm(c)->get_bias()-rho;
00497
00498 for (int32_t c=0; c<m_machines->get_num_elements(); c++)
00499 {
00500 float64_t v=0;
00501
00502 for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++)
00503 {
00504 float64_t alpha=get_svm(c)->get_alpha(i);
00505 int32_t svidx=get_svm(c)->get_support_vector(i);
00506 v += alpha*m_kernel->kernel(svidx, num);
00507 }
00508
00509 outputs[c] += v;
00510 for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00511 outputs[j] -= v/m_machines->get_num_elements();
00512 }
00513
00514 for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00515 outputs[j]/=norm_wcw[j];
00516
00517 float64_t max_out=outputs[0];
00518 for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00519 {
00520 if (outputs[j]>max_out)
00521 {
00522 max_out=outputs[j];
00523 winner=j;
00524 }
00525 }
00526 }
00527 #ifdef USE_SVMLIGHT
00528 else if (scatter_type == NO_BIAS_SVMLIGHT)
00529 {
00530 SG_ERROR("Use classify...\n");
00531 }
00532 #endif //USE_SVMLIGHT
00533 else
00534 {
00535 float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0];
00536
00537 for (int32_t i=1; i<m_machines->get_num_elements(); i++)
00538 {
00539 outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i];
00540 if (outputs[i]>max_out)
00541 {
00542 winner=i;
00543 max_out=outputs[i];
00544 }
00545 }
00546 }
00547
00548 SG_FREE(outputs);
00549 return winner;
00550 }