MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014 
00015 using namespace shogun;
00016 
00017 CMultiClassSVM::CMultiClassSVM(void)
00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL)
00019 {
00020     SG_UNSTABLE("CMultiClassSVM::CMultiClassSVM(void)", "\n");
00021     init();
00022 }
00023 
00024 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00025 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00026 {
00027     init();
00028 }
00029 
00030 CMultiClassSVM::CMultiClassSVM(
00031     EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00032 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00033 {
00034     init();
00035 }
00036 
00037 CMultiClassSVM::~CMultiClassSVM()
00038 {
00039     cleanup();
00040 }
00041 
00042 void
00043 CMultiClassSVM::init(void)
00044 {
00045     m_parameters->add((machine_int_t*) &multiclass_type,
00046                       "multiclass_type", "Type of MultiClassSVM.");
00047     m_parameters->add(&m_num_classes, "m_num_classes",
00048                       "Number of classes.");
00049     m_parameters->add_vector((CSGObject***) &m_svms,
00050                              &m_num_svms, "m_svms");
00051 }
00052 
00053 void CMultiClassSVM::cleanup()
00054 {
00055     for (int32_t i=0; i<m_num_svms; i++)
00056         SG_UNREF(m_svms[i]);
00057 
00058     delete[] m_svms;
00059     m_num_svms=0;
00060     m_svms=NULL;
00061 }
00062 
00063 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00064 {
00065     if (num_classes>0)
00066     {
00067         cleanup();
00068 
00069         m_num_classes=num_classes;
00070 
00071         if (multiclass_type==ONE_VS_REST)
00072             m_num_svms=num_classes;
00073         else if (multiclass_type==ONE_VS_ONE)
00074             m_num_svms=num_classes*(num_classes-1)/2;
00075         else
00076             SG_ERROR("unknown multiclass type\n");
00077 
00078         m_svms=new CSVM*[m_num_svms];
00079         if (m_svms)
00080         {
00081             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00082             return true;
00083         }
00084     }
00085     return false;
00086 }
00087 
00088 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00089 {
00090     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00091     {
00092         SG_REF(svm);
00093         m_svms[num]=svm;
00094         return true;
00095     }
00096     return false;
00097 }
00098 
00099 CLabels* CMultiClassSVM::classify()
00100 {
00101     if (multiclass_type==ONE_VS_REST)
00102         return classify_one_vs_rest();
00103     else if (multiclass_type==ONE_VS_ONE)
00104         return classify_one_vs_one();
00105     else
00106         SG_ERROR("unknown multiclass type\n");
00107 
00108     return NULL;
00109 }
00110 
00111 CLabels* CMultiClassSVM::classify_one_vs_one()
00112 {
00113     ASSERT(m_num_svms>0);
00114     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00115     CLabels* result=NULL;
00116 
00117     if (!kernel)
00118     {
00119         SG_ERROR( "SVM can not proceed without kernel!\n");
00120         return false ;
00121     }
00122 
00123     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00124     {
00125         int32_t num_vectors=kernel->get_num_vec_rhs();
00126 
00127         result=new CLabels(num_vectors);
00128         SG_REF(result);
00129 
00130         ASSERT(num_vectors==result->get_num_labels());
00131         CLabels** outputs=new CLabels*[m_num_svms];
00132 
00133         for (int32_t i=0; i<m_num_svms; i++)
00134         {
00135             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00136             ASSERT(m_svms[i]);
00137             m_svms[i]->set_kernel(kernel);
00138             outputs[i]=m_svms[i]->classify();
00139         }
00140 
00141         int32_t* votes=new int32_t[m_num_classes];
00142         for (int32_t v=0; v<num_vectors; v++)
00143         {
00144             int32_t s=0;
00145             memset(votes, 0, sizeof(int32_t)*m_num_classes);
00146 
00147             for (int32_t i=0; i<m_num_classes; i++)
00148             {
00149                 for (int32_t j=i+1; j<m_num_classes; j++)
00150                 {
00151                     if (outputs[s++]->get_label(v)>0)
00152                         votes[i]++;
00153                     else
00154                         votes[j]++;
00155                 }
00156             }
00157 
00158             int32_t winner=0;
00159             int32_t max_votes=votes[0];
00160 
00161             for (int32_t i=1; i<m_num_classes; i++)
00162             {
00163                 if (votes[i]>max_votes)
00164                 {
00165                     max_votes=votes[i];
00166                     winner=i;
00167                 }
00168             }
00169 
00170             result->set_label(v, winner);
00171         }
00172 
00173         delete[] votes;
00174 
00175         for (int32_t i=0; i<m_num_svms; i++)
00176             SG_UNREF(outputs[i]);
00177         delete[] outputs;
00178     }
00179 
00180     return result;
00181 }
00182 
00183 CLabels* CMultiClassSVM::classify_one_vs_rest()
00184 {
00185     ASSERT(m_num_svms>0);
00186     CLabels* result=NULL;
00187 
00188     if (!kernel)
00189     {
00190         SG_ERROR( "SVM can not proceed without kernel!\n");
00191         return false ;
00192     }
00193 
00194     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00195     {
00196         int32_t num_vectors=kernel->get_num_vec_rhs();
00197 
00198         result=new CLabels(num_vectors);
00199         SG_REF(result);
00200 
00201         ASSERT(num_vectors==result->get_num_labels());
00202         CLabels** outputs=new CLabels*[m_num_svms];
00203 
00204         for (int32_t i=0; i<m_num_svms; i++)
00205         {
00206             ASSERT(m_svms[i]);
00207             m_svms[i]->set_kernel(kernel);
00208             outputs[i]=m_svms[i]->classify();
00209         }
00210 
00211         for (int32_t i=0; i<num_vectors; i++)
00212         {
00213             int32_t winner=0;
00214             float64_t max_out=outputs[0]->get_label(i);
00215 
00216             for (int32_t j=1; j<m_num_svms; j++)
00217             {
00218                 float64_t out=outputs[j]->get_label(i);
00219 
00220                 if (out>max_out)
00221                 {
00222                     winner=j;
00223                     max_out=out;
00224                 }
00225             }
00226 
00227             result->set_label(i, winner);
00228         }
00229 
00230         for (int32_t i=0; i<m_num_svms; i++)
00231             SG_UNREF(outputs[i]);
00232 
00233         delete[] outputs;
00234     }
00235 
00236     return result;
00237 }
00238 
00239 float64_t CMultiClassSVM::classify_example(int32_t num)
00240 {
00241     if (multiclass_type==ONE_VS_REST)
00242         return classify_example_one_vs_rest(num);
00243     else if (multiclass_type==ONE_VS_ONE)
00244         return classify_example_one_vs_one(num);
00245     else
00246         SG_ERROR("unknown multiclass type\n");
00247 
00248     return 0;
00249 }
00250 
00251 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00252 {
00253     ASSERT(m_num_svms>0);
00254     float64_t* outputs=new float64_t[m_num_svms];
00255     int32_t winner=0;
00256     float64_t max_out=m_svms[0]->classify_example(num);
00257 
00258     for (int32_t i=1; i<m_num_svms; i++)
00259     {
00260         outputs[i]=m_svms[i]->classify_example(num);
00261         if (outputs[i]>max_out)
00262         {
00263             winner=i;
00264             max_out=outputs[i];
00265         }
00266     }
00267     delete[] outputs;
00268 
00269     return winner;
00270 }
00271 
00272 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00273 {
00274     ASSERT(m_num_svms>0);
00275     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00276 
00277     int32_t* votes=new int32_t[m_num_classes];
00278     int32_t s=0;
00279 
00280     for (int32_t i=0; i<m_num_classes; i++)
00281     {
00282         for (int32_t j=i+1; j<m_num_classes; j++)
00283         {
00284             if (m_svms[s++]->classify_example(num)>0)
00285                 votes[i]++;
00286             else
00287                 votes[j]++;
00288         }
00289     }
00290 
00291     int32_t winner=0;
00292     int32_t max_votes=votes[0];
00293 
00294     for (int32_t i=1; i<m_num_classes; i++)
00295     {
00296         if (votes[i]>max_votes)
00297         {
00298             max_votes=votes[i];
00299             winner=i;
00300         }
00301     }
00302 
00303     delete[] votes;
00304 
00305     return winner;
00306 }
00307 
00308 bool CMultiClassSVM::load(FILE* modelfl)
00309 {
00310     bool result=true;
00311     char char_buffer[1024];
00312     int32_t int_buffer;
00313     float64_t double_buffer;
00314     int32_t line_number=1;
00315     int32_t svm_idx=-1;
00316 
00317     SG_SET_LOCALE_C;
00318 
00319     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00320         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00321     else
00322     {
00323         char_buffer[15]='\0';
00324         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00325             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00326 
00327         line_number++;
00328     }
00329 
00330     int_buffer=0;
00331     if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00332         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00333 
00334     if (!feof(modelfl))
00335         line_number++;
00336 
00337     if (int_buffer != multiclass_type)
00338         SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00339 
00340     int_buffer=0;
00341     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00342         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00343 
00344     if (!feof(modelfl))
00345         line_number++;
00346 
00347     if (int_buffer < 2)
00348         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00349 
00350     create_multiclass_svm(int_buffer);
00351 
00352     int_buffer=0;
00353     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00354         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355 
00356     if (!feof(modelfl))
00357         line_number++;
00358 
00359     if (m_num_svms != int_buffer)
00360         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00361 
00362     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00363         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00364 
00365     if (!feof(modelfl))
00366         line_number++;
00367 
00368     for (int32_t n=0; n<m_num_svms; n++)
00369     {
00370         svm_idx=-1;
00371         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00372         {
00373             result=false;
00374             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00375         }
00376         else
00377         {
00378             char_buffer[4]='\0';
00379             if (strncmp("%SVM", char_buffer, 4)!=0)
00380             {
00381                 result=false;
00382                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00383             }
00384 
00385             if (svm_idx != n)
00386                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00387 
00388             line_number++;
00389         }
00390 
00391         int_buffer=0;
00392         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00393             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00394 
00395         if (svm_idx != n)
00396             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00397 
00398         if (!feof(modelfl))
00399             line_number++;
00400 
00401         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00402         CSVM* svm=new CSVM(int_buffer);
00403 
00404         double_buffer=0;
00405 
00406         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00407             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00408 
00409         if (svm_idx != n)
00410             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00411 
00412         if (!feof(modelfl))
00413             line_number++;
00414 
00415         svm->set_bias(double_buffer);
00416 
00417         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00418             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00419 
00420         if (svm_idx != n)
00421             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00422 
00423         if (!feof(modelfl))
00424             line_number++;
00425 
00426         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00427         {
00428             double_buffer=0;
00429             int_buffer=0;
00430 
00431             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00432                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00433 
00434             if (!feof(modelfl))
00435                 line_number++;
00436 
00437             svm->set_support_vector(i, int_buffer);
00438             svm->set_alpha(i, double_buffer);
00439         }
00440 
00441         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00442         {
00443             result=false;
00444             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00445         }
00446         else
00447         {
00448             char_buffer[3]='\0';
00449             if (strcmp("];", char_buffer)!=0)
00450             {
00451                 result=false;
00452                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00453             }
00454             line_number++;
00455         }
00456 
00457         set_svm(n, svm);
00458     }
00459 
00460     svm_loaded=result;
00461 
00462     SG_RESET_LOCALE;
00463     return result;
00464 }
00465 
00466 bool CMultiClassSVM::save(FILE* modelfl)
00467 {
00468     SG_SET_LOCALE_C;
00469 
00470     if (!kernel)
00471         SG_ERROR("Kernel not defined!\n");
00472 
00473     if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00474         SG_ERROR("Multiclass SVM not trained!\n");
00475 
00476     SG_INFO( "Writing model file...");
00477     fprintf(modelfl,"%%MultiClassSVM\n");
00478     fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00479     fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00480     fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00481     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00482 
00483     for (int32_t i=0; i<m_num_svms; i++)
00484     {
00485         CSVM* svm=m_svms[i];
00486         ASSERT(svm);
00487         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00488         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00489         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00490 
00491         fprintf(modelfl, "alphas%d=[\n", i);
00492 
00493         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00494         {
00495             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00496                     svm->get_alpha(j), svm->get_support_vector(j));
00497         }
00498 
00499         fprintf(modelfl, "];\n");
00500     }
00501 
00502     SG_RESET_LOCALE;
00503     SG_DONE();
00504     return true ;
00505 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation