00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/classifier/svm/MultiClassSVM.h>
00014
00015 using namespace shogun;
00016
00017 CMultiClassSVM::CMultiClassSVM()
00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL)
00019 {
00020 init();
00021 }
00022
00023 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00024 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00025 {
00026 init();
00027 }
00028
00029 CMultiClassSVM::CMultiClassSVM(
00030 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00031 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00032 {
00033 init();
00034 }
00035
00036 CMultiClassSVM::~CMultiClassSVM()
00037 {
00038 cleanup();
00039 }
00040
00041 void CMultiClassSVM::init()
00042 {
00043 m_parameters->add((machine_int_t*) &multiclass_type,
00044 "multiclass_type", "Type of MultiClassSVM.");
00045 m_parameters->add(&m_num_classes, "m_num_classes",
00046 "Number of classes.");
00047 m_parameters->add_vector((CSGObject***) &m_svms,
00048 &m_num_svms, "m_svms");
00049 }
00050
00051 void CMultiClassSVM::cleanup()
00052 {
00053 for (int32_t i=0; i<m_num_svms; i++)
00054 SG_UNREF(m_svms[i]);
00055
00056 SG_FREE(m_svms);
00057 m_num_svms=0;
00058 m_svms=NULL;
00059 }
00060
00061 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00062 {
00063 if (num_classes>0)
00064 {
00065 cleanup();
00066
00067 m_num_classes=num_classes;
00068
00069 if (multiclass_type==ONE_VS_REST)
00070 m_num_svms=num_classes;
00071 else if (multiclass_type==ONE_VS_ONE)
00072 m_num_svms=num_classes*(num_classes-1)/2;
00073 else
00074 SG_ERROR("unknown multiclass type\n");
00075
00076 m_svms=SG_MALLOC(CSVM*, m_num_svms);
00077 if (m_svms)
00078 {
00079 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00080 return true;
00081 }
00082 }
00083 return false;
00084 }
00085
00086 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00087 {
00088 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00089 {
00090 SG_REF(svm);
00091 m_svms[num]=svm;
00092 return true;
00093 }
00094 return false;
00095 }
00096
00097 CLabels* CMultiClassSVM::apply()
00098 {
00099 if (multiclass_type==ONE_VS_REST)
00100 return classify_one_vs_rest();
00101 else if (multiclass_type==ONE_VS_ONE)
00102 return classify_one_vs_one();
00103 else
00104 SG_ERROR("unknown multiclass type\n");
00105
00106 return NULL;
00107 }
00108
00109 CLabels* CMultiClassSVM::classify_one_vs_one()
00110 {
00111 ASSERT(m_num_svms>0);
00112 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00113 CLabels* result=NULL;
00114
00115 if (!kernel)
00116 {
00117 SG_ERROR( "SVM can not proceed without kernel!\n");
00118 return false ;
00119 }
00120
00121 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00122 {
00123 int32_t num_vectors=kernel->get_num_vec_rhs();
00124
00125 result=new CLabels(num_vectors);
00126 SG_REF(result);
00127
00128 ASSERT(num_vectors==result->get_num_labels());
00129 CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
00130
00131 for (int32_t i=0; i<m_num_svms; i++)
00132 {
00133 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00134 ASSERT(m_svms[i]);
00135 m_svms[i]->set_kernel(kernel);
00136 outputs[i]=m_svms[i]->apply();
00137 }
00138
00139 int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
00140 for (int32_t v=0; v<num_vectors; v++)
00141 {
00142 int32_t s=0;
00143 memset(votes, 0, sizeof(int32_t)*m_num_classes);
00144
00145 for (int32_t i=0; i<m_num_classes; i++)
00146 {
00147 for (int32_t j=i+1; j<m_num_classes; j++)
00148 {
00149 if (outputs[s++]->get_label(v)>0)
00150 votes[i]++;
00151 else
00152 votes[j]++;
00153 }
00154 }
00155
00156 int32_t winner=0;
00157 int32_t max_votes=votes[0];
00158
00159 for (int32_t i=1; i<m_num_classes; i++)
00160 {
00161 if (votes[i]>max_votes)
00162 {
00163 max_votes=votes[i];
00164 winner=i;
00165 }
00166 }
00167
00168 result->set_label(v, winner);
00169 }
00170
00171 SG_FREE(votes);
00172
00173 for (int32_t i=0; i<m_num_svms; i++)
00174 SG_UNREF(outputs[i]);
00175 SG_FREE(outputs);
00176 }
00177
00178 return result;
00179 }
00180
00181 CLabels* CMultiClassSVM::classify_one_vs_rest()
00182 {
00183 ASSERT(m_num_svms>0);
00184 CLabels* result=NULL;
00185
00186 if (!kernel)
00187 {
00188 SG_ERROR( "SVM can not proceed without kernel!\n");
00189 return false ;
00190 }
00191
00192 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00193 {
00194 int32_t num_vectors=kernel->get_num_vec_rhs();
00195
00196 result=new CLabels(num_vectors);
00197 SG_REF(result);
00198
00199 ASSERT(num_vectors==result->get_num_labels());
00200 CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
00201
00202 for (int32_t i=0; i<m_num_svms; i++)
00203 {
00204 ASSERT(m_svms[i]);
00205 m_svms[i]->set_kernel(kernel);
00206 outputs[i]=m_svms[i]->apply();
00207 }
00208
00209 for (int32_t i=0; i<num_vectors; i++)
00210 {
00211 int32_t winner=0;
00212 float64_t max_out=outputs[0]->get_label(i);
00213
00214 for (int32_t j=1; j<m_num_svms; j++)
00215 {
00216 float64_t out=outputs[j]->get_label(i);
00217
00218 if (out>max_out)
00219 {
00220 winner=j;
00221 max_out=out;
00222 }
00223 }
00224
00225 result->set_label(i, winner);
00226 }
00227
00228 for (int32_t i=0; i<m_num_svms; i++)
00229 SG_UNREF(outputs[i]);
00230
00231 SG_FREE(outputs);
00232 }
00233
00234 return result;
00235 }
00236
00237 float64_t CMultiClassSVM::apply(int32_t num)
00238 {
00239 if (multiclass_type==ONE_VS_REST)
00240 return classify_example_one_vs_rest(num);
00241 else if (multiclass_type==ONE_VS_ONE)
00242 return classify_example_one_vs_one(num);
00243 else
00244 SG_ERROR("unknown multiclass type\n");
00245
00246 return 0;
00247 }
00248
00249 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00250 {
00251 ASSERT(m_num_svms>0);
00252 float64_t* outputs=SG_MALLOC(float64_t, m_num_svms);
00253 int32_t winner=0;
00254 float64_t max_out=m_svms[0]->apply(num);
00255
00256 for (int32_t i=1; i<m_num_svms; i++)
00257 {
00258 outputs[i]=m_svms[i]->apply(num);
00259 if (outputs[i]>max_out)
00260 {
00261 winner=i;
00262 max_out=outputs[i];
00263 }
00264 }
00265 SG_FREE(outputs);
00266
00267 return winner;
00268 }
00269
00270 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00271 {
00272 ASSERT(m_num_svms>0);
00273 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00274
00275 int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
00276 int32_t s=0;
00277
00278 for (int32_t i=0; i<m_num_classes; i++)
00279 {
00280 for (int32_t j=i+1; j<m_num_classes; j++)
00281 {
00282 if (m_svms[s++]->apply(num)>0)
00283 votes[i]++;
00284 else
00285 votes[j]++;
00286 }
00287 }
00288
00289 int32_t winner=0;
00290 int32_t max_votes=votes[0];
00291
00292 for (int32_t i=1; i<m_num_classes; i++)
00293 {
00294 if (votes[i]>max_votes)
00295 {
00296 max_votes=votes[i];
00297 winner=i;
00298 }
00299 }
00300
00301 SG_FREE(votes);
00302
00303 return winner;
00304 }
00305
00306 bool CMultiClassSVM::load(FILE* modelfl)
00307 {
00308 bool result=true;
00309 char char_buffer[1024];
00310 int32_t int_buffer;
00311 float64_t double_buffer;
00312 int32_t line_number=1;
00313 int32_t svm_idx=-1;
00314
00315 SG_SET_LOCALE_C;
00316
00317 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00318 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00319 else
00320 {
00321 char_buffer[15]='\0';
00322 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00323 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00324
00325 line_number++;
00326 }
00327
00328 int_buffer=0;
00329 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00330 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00331
00332 if (!feof(modelfl))
00333 line_number++;
00334
00335 if (int_buffer != multiclass_type)
00336 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00337
00338 int_buffer=0;
00339 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00340 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00341
00342 if (!feof(modelfl))
00343 line_number++;
00344
00345 if (int_buffer < 2)
00346 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00347
00348 create_multiclass_svm(int_buffer);
00349
00350 int_buffer=0;
00351 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00352 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00353
00354 if (!feof(modelfl))
00355 line_number++;
00356
00357 if (m_num_svms != int_buffer)
00358 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00359
00360 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00361 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00362
00363 if (!feof(modelfl))
00364 line_number++;
00365
00366 for (int32_t n=0; n<m_num_svms; n++)
00367 {
00368 svm_idx=-1;
00369 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00370 {
00371 result=false;
00372 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00373 }
00374 else
00375 {
00376 char_buffer[4]='\0';
00377 if (strncmp("%SVM", char_buffer, 4)!=0)
00378 {
00379 result=false;
00380 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00381 }
00382
00383 if (svm_idx != n)
00384 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00385
00386 line_number++;
00387 }
00388
00389 int_buffer=0;
00390 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00391 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00392
00393 if (svm_idx != n)
00394 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00395
00396 if (!feof(modelfl))
00397 line_number++;
00398
00399 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00400 CSVM* svm=new CSVM(int_buffer);
00401
00402 double_buffer=0;
00403
00404 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00405 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00406
00407 if (svm_idx != n)
00408 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00409
00410 if (!feof(modelfl))
00411 line_number++;
00412
00413 svm->set_bias(double_buffer);
00414
00415 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00416 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00417
00418 if (svm_idx != n)
00419 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00420
00421 if (!feof(modelfl))
00422 line_number++;
00423
00424 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00425 {
00426 double_buffer=0;
00427 int_buffer=0;
00428
00429 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00430 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00431
00432 if (!feof(modelfl))
00433 line_number++;
00434
00435 svm->set_support_vector(i, int_buffer);
00436 svm->set_alpha(i, double_buffer);
00437 }
00438
00439 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00440 {
00441 result=false;
00442 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00443 }
00444 else
00445 {
00446 char_buffer[3]='\0';
00447 if (strcmp("];", char_buffer)!=0)
00448 {
00449 result=false;
00450 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00451 }
00452 line_number++;
00453 }
00454
00455 set_svm(n, svm);
00456 }
00457
00458 svm_loaded=result;
00459
00460 SG_RESET_LOCALE;
00461 return result;
00462 }
00463
00464 bool CMultiClassSVM::save(FILE* modelfl)
00465 {
00466 SG_SET_LOCALE_C;
00467
00468 if (!kernel)
00469 SG_ERROR("Kernel not defined!\n");
00470
00471 if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00472 SG_ERROR("Multiclass SVM not trained!\n");
00473
00474 SG_INFO( "Writing model file...");
00475 fprintf(modelfl,"%%MultiClassSVM\n");
00476 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00477 fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00478 fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00479 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00480
00481 for (int32_t i=0; i<m_num_svms; i++)
00482 {
00483 CSVM* svm=m_svms[i];
00484 ASSERT(svm);
00485 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00486 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00487 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00488
00489 fprintf(modelfl, "alphas%d=[\n", i);
00490
00491 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00492 {
00493 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00494 svm->get_alpha(j), svm->get_support_vector(j));
00495 }
00496
00497 fprintf(modelfl, "];\n");
00498 }
00499
00500 SG_RESET_LOCALE;
00501 SG_DONE();
00502 return true ;
00503 }