00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014
00015 using namespace shogun;
00016
00017 CMultiClassSVM::CMultiClassSVM(void)
00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL)
00019 {
00020 SG_UNSTABLE("CMultiClassSVM::CMultiClassSVM(void)", "\n");
00021 init();
00022 }
00023
00024 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00025 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00026 {
00027 init();
00028 }
00029
00030 CMultiClassSVM::CMultiClassSVM(
00031 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00032 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00033 {
00034 init();
00035 }
00036
00037 CMultiClassSVM::~CMultiClassSVM()
00038 {
00039 cleanup();
00040 }
00041
00042 void
00043 CMultiClassSVM::init(void)
00044 {
00045 m_parameters->add((machine_int_t*) &multiclass_type,
00046 "multiclass_type", "Type of MultiClassSVM.");
00047 m_parameters->add(&m_num_classes, "m_num_classes",
00048 "Number of classes.");
00049 m_parameters->add_vector((CSGObject***) &m_svms,
00050 &m_num_svms, "m_svms");
00051 }
00052
00053 void CMultiClassSVM::cleanup()
00054 {
00055 for (int32_t i=0; i<m_num_svms; i++)
00056 SG_UNREF(m_svms[i]);
00057
00058 delete[] m_svms;
00059 m_num_svms=0;
00060 m_svms=NULL;
00061 }
00062
00063 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00064 {
00065 if (num_classes>0)
00066 {
00067 cleanup();
00068
00069 m_num_classes=num_classes;
00070
00071 if (multiclass_type==ONE_VS_REST)
00072 m_num_svms=num_classes;
00073 else if (multiclass_type==ONE_VS_ONE)
00074 m_num_svms=num_classes*(num_classes-1)/2;
00075 else
00076 SG_ERROR("unknown multiclass type\n");
00077
00078 m_svms=new CSVM*[m_num_svms];
00079 if (m_svms)
00080 {
00081 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00082 return true;
00083 }
00084 }
00085 return false;
00086 }
00087
00088 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00089 {
00090 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00091 {
00092 SG_REF(svm);
00093 m_svms[num]=svm;
00094 return true;
00095 }
00096 return false;
00097 }
00098
00099 CLabels* CMultiClassSVM::classify()
00100 {
00101 if (multiclass_type==ONE_VS_REST)
00102 return classify_one_vs_rest();
00103 else if (multiclass_type==ONE_VS_ONE)
00104 return classify_one_vs_one();
00105 else
00106 SG_ERROR("unknown multiclass type\n");
00107
00108 return NULL;
00109 }
00110
00111 CLabels* CMultiClassSVM::classify_one_vs_one()
00112 {
00113 ASSERT(m_num_svms>0);
00114 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00115 CLabels* result=NULL;
00116
00117 if (!kernel)
00118 {
00119 SG_ERROR( "SVM can not proceed without kernel!\n");
00120 return false ;
00121 }
00122
00123 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00124 {
00125 int32_t num_vectors=kernel->get_num_vec_rhs();
00126
00127 result=new CLabels(num_vectors);
00128 SG_REF(result);
00129
00130 ASSERT(num_vectors==result->get_num_labels());
00131 CLabels** outputs=new CLabels*[m_num_svms];
00132
00133 for (int32_t i=0; i<m_num_svms; i++)
00134 {
00135 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00136 ASSERT(m_svms[i]);
00137 m_svms[i]->set_kernel(kernel);
00138 outputs[i]=m_svms[i]->classify();
00139 }
00140
00141 int32_t* votes=new int32_t[m_num_classes];
00142 for (int32_t v=0; v<num_vectors; v++)
00143 {
00144 int32_t s=0;
00145 memset(votes, 0, sizeof(int32_t)*m_num_classes);
00146
00147 for (int32_t i=0; i<m_num_classes; i++)
00148 {
00149 for (int32_t j=i+1; j<m_num_classes; j++)
00150 {
00151 if (outputs[s++]->get_label(v)>0)
00152 votes[i]++;
00153 else
00154 votes[j]++;
00155 }
00156 }
00157
00158 int32_t winner=0;
00159 int32_t max_votes=votes[0];
00160
00161 for (int32_t i=1; i<m_num_classes; i++)
00162 {
00163 if (votes[i]>max_votes)
00164 {
00165 max_votes=votes[i];
00166 winner=i;
00167 }
00168 }
00169
00170 result->set_label(v, winner);
00171 }
00172
00173 delete[] votes;
00174
00175 for (int32_t i=0; i<m_num_svms; i++)
00176 SG_UNREF(outputs[i]);
00177 delete[] outputs;
00178 }
00179
00180 return result;
00181 }
00182
00183 CLabels* CMultiClassSVM::classify_one_vs_rest()
00184 {
00185 ASSERT(m_num_svms>0);
00186 CLabels* result=NULL;
00187
00188 if (!kernel)
00189 {
00190 SG_ERROR( "SVM can not proceed without kernel!\n");
00191 return false ;
00192 }
00193
00194 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00195 {
00196 int32_t num_vectors=kernel->get_num_vec_rhs();
00197
00198 result=new CLabels(num_vectors);
00199 SG_REF(result);
00200
00201 ASSERT(num_vectors==result->get_num_labels());
00202 CLabels** outputs=new CLabels*[m_num_svms];
00203
00204 for (int32_t i=0; i<m_num_svms; i++)
00205 {
00206 ASSERT(m_svms[i]);
00207 m_svms[i]->set_kernel(kernel);
00208 outputs[i]=m_svms[i]->classify();
00209 }
00210
00211 for (int32_t i=0; i<num_vectors; i++)
00212 {
00213 int32_t winner=0;
00214 float64_t max_out=outputs[0]->get_label(i);
00215
00216 for (int32_t j=1; j<m_num_svms; j++)
00217 {
00218 float64_t out=outputs[j]->get_label(i);
00219
00220 if (out>max_out)
00221 {
00222 winner=j;
00223 max_out=out;
00224 }
00225 }
00226
00227 result->set_label(i, winner);
00228 }
00229
00230 for (int32_t i=0; i<m_num_svms; i++)
00231 SG_UNREF(outputs[i]);
00232
00233 delete[] outputs;
00234 }
00235
00236 return result;
00237 }
00238
00239 float64_t CMultiClassSVM::classify_example(int32_t num)
00240 {
00241 if (multiclass_type==ONE_VS_REST)
00242 return classify_example_one_vs_rest(num);
00243 else if (multiclass_type==ONE_VS_ONE)
00244 return classify_example_one_vs_one(num);
00245 else
00246 SG_ERROR("unknown multiclass type\n");
00247
00248 return 0;
00249 }
00250
00251 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00252 {
00253 ASSERT(m_num_svms>0);
00254 float64_t* outputs=new float64_t[m_num_svms];
00255 int32_t winner=0;
00256 float64_t max_out=m_svms[0]->classify_example(num);
00257
00258 for (int32_t i=1; i<m_num_svms; i++)
00259 {
00260 outputs[i]=m_svms[i]->classify_example(num);
00261 if (outputs[i]>max_out)
00262 {
00263 winner=i;
00264 max_out=outputs[i];
00265 }
00266 }
00267 delete[] outputs;
00268
00269 return winner;
00270 }
00271
00272 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00273 {
00274 ASSERT(m_num_svms>0);
00275 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00276
00277 int32_t* votes=new int32_t[m_num_classes];
00278 int32_t s=0;
00279
00280 for (int32_t i=0; i<m_num_classes; i++)
00281 {
00282 for (int32_t j=i+1; j<m_num_classes; j++)
00283 {
00284 if (m_svms[s++]->classify_example(num)>0)
00285 votes[i]++;
00286 else
00287 votes[j]++;
00288 }
00289 }
00290
00291 int32_t winner=0;
00292 int32_t max_votes=votes[0];
00293
00294 for (int32_t i=1; i<m_num_classes; i++)
00295 {
00296 if (votes[i]>max_votes)
00297 {
00298 max_votes=votes[i];
00299 winner=i;
00300 }
00301 }
00302
00303 delete[] votes;
00304
00305 return winner;
00306 }
00307
00308 bool CMultiClassSVM::load(FILE* modelfl)
00309 {
00310 bool result=true;
00311 char char_buffer[1024];
00312 int32_t int_buffer;
00313 float64_t double_buffer;
00314 int32_t line_number=1;
00315 int32_t svm_idx=-1;
00316
00317 SG_SET_LOCALE_C;
00318
00319 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00320 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00321 else
00322 {
00323 char_buffer[15]='\0';
00324 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00325 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00326
00327 line_number++;
00328 }
00329
00330 int_buffer=0;
00331 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00332 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00333
00334 if (!feof(modelfl))
00335 line_number++;
00336
00337 if (int_buffer != multiclass_type)
00338 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00339
00340 int_buffer=0;
00341 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00342 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00343
00344 if (!feof(modelfl))
00345 line_number++;
00346
00347 if (int_buffer < 2)
00348 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00349
00350 create_multiclass_svm(int_buffer);
00351
00352 int_buffer=0;
00353 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00354 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355
00356 if (!feof(modelfl))
00357 line_number++;
00358
00359 if (m_num_svms != int_buffer)
00360 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00361
00362 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00363 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00364
00365 if (!feof(modelfl))
00366 line_number++;
00367
00368 for (int32_t n=0; n<m_num_svms; n++)
00369 {
00370 svm_idx=-1;
00371 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00372 {
00373 result=false;
00374 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00375 }
00376 else
00377 {
00378 char_buffer[4]='\0';
00379 if (strncmp("%SVM", char_buffer, 4)!=0)
00380 {
00381 result=false;
00382 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00383 }
00384
00385 if (svm_idx != n)
00386 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00387
00388 line_number++;
00389 }
00390
00391 int_buffer=0;
00392 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00393 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00394
00395 if (svm_idx != n)
00396 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00397
00398 if (!feof(modelfl))
00399 line_number++;
00400
00401 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00402 CSVM* svm=new CSVM(int_buffer);
00403
00404 double_buffer=0;
00405
00406 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00407 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00408
00409 if (svm_idx != n)
00410 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00411
00412 if (!feof(modelfl))
00413 line_number++;
00414
00415 svm->set_bias(double_buffer);
00416
00417 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00418 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00419
00420 if (svm_idx != n)
00421 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00422
00423 if (!feof(modelfl))
00424 line_number++;
00425
00426 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00427 {
00428 double_buffer=0;
00429 int_buffer=0;
00430
00431 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00432 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00433
00434 if (!feof(modelfl))
00435 line_number++;
00436
00437 svm->set_support_vector(i, int_buffer);
00438 svm->set_alpha(i, double_buffer);
00439 }
00440
00441 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00442 {
00443 result=false;
00444 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00445 }
00446 else
00447 {
00448 char_buffer[3]='\0';
00449 if (strcmp("];", char_buffer)!=0)
00450 {
00451 result=false;
00452 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00453 }
00454 line_number++;
00455 }
00456
00457 set_svm(n, svm);
00458 }
00459
00460 svm_loaded=result;
00461
00462 SG_RESET_LOCALE;
00463 return result;
00464 }
00465
00466 bool CMultiClassSVM::save(FILE* modelfl)
00467 {
00468 SG_SET_LOCALE_C;
00469
00470 if (!kernel)
00471 SG_ERROR("Kernel not defined!\n");
00472
00473 if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00474 SG_ERROR("Multiclass SVM not trained!\n");
00475
00476 SG_INFO( "Writing model file...");
00477 fprintf(modelfl,"%%MultiClassSVM\n");
00478 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00479 fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00480 fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00481 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00482
00483 for (int32_t i=0; i<m_num_svms; i++)
00484 {
00485 CSVM* svm=m_svms[i];
00486 ASSERT(svm);
00487 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00488 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00489 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00490
00491 fprintf(modelfl, "alphas%d=[\n", i);
00492
00493 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00494 {
00495 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00496 svm->get_alpha(j), svm->get_support_vector(j));
00497 }
00498
00499 fprintf(modelfl, "];\n");
00500 }
00501
00502 SG_RESET_LOCALE;
00503 SG_DONE();
00504 return true ;
00505 }