MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/classifier/svm/MultiClassSVM.h>
00014 
00015 using namespace shogun;
00016 
00017 CMultiClassSVM::CMultiClassSVM()
00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL)
00019 {
00020     init();
00021 }
00022 
00023 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00024 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00025 {
00026     init();
00027 }
00028 
00029 CMultiClassSVM::CMultiClassSVM(
00030     EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00031 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00032 {
00033     init();
00034 }
00035 
00036 CMultiClassSVM::~CMultiClassSVM()
00037 {
00038     cleanup();
00039 }
00040 
00041 void CMultiClassSVM::init()
00042 {
00043     m_parameters->add((machine_int_t*) &multiclass_type,
00044                       "multiclass_type", "Type of MultiClassSVM.");
00045     m_parameters->add(&m_num_classes, "m_num_classes",
00046                       "Number of classes.");
00047     m_parameters->add_vector((CSGObject***) &m_svms,
00048                              &m_num_svms, "m_svms");
00049 }
00050 
00051 void CMultiClassSVM::cleanup()
00052 {
00053     for (int32_t i=0; i<m_num_svms; i++)
00054         SG_UNREF(m_svms[i]);
00055 
00056     SG_FREE(m_svms);
00057     m_num_svms=0;
00058     m_svms=NULL;
00059 }
00060 
00061 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00062 {
00063     if (num_classes>0)
00064     {
00065         cleanup();
00066 
00067         m_num_classes=num_classes;
00068 
00069         if (multiclass_type==ONE_VS_REST)
00070             m_num_svms=num_classes;
00071         else if (multiclass_type==ONE_VS_ONE)
00072             m_num_svms=num_classes*(num_classes-1)/2;
00073         else
00074             SG_ERROR("unknown multiclass type\n");
00075 
00076         m_svms=SG_MALLOC(CSVM*, m_num_svms);
00077         if (m_svms)
00078         {
00079             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00080             return true;
00081         }
00082     }
00083     return false;
00084 }
00085 
00086 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00087 {
00088     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00089     {
00090         SG_REF(svm);
00091         m_svms[num]=svm;
00092         return true;
00093     }
00094     return false;
00095 }
00096 
00097 CLabels* CMultiClassSVM::apply()
00098 {
00099     if (multiclass_type==ONE_VS_REST)
00100         return classify_one_vs_rest();
00101     else if (multiclass_type==ONE_VS_ONE)
00102         return classify_one_vs_one();
00103     else
00104         SG_ERROR("unknown multiclass type\n");
00105 
00106     return NULL;
00107 }
00108 
00109 CLabels* CMultiClassSVM::classify_one_vs_one()
00110 {
00111     ASSERT(m_num_svms>0);
00112     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00113     CLabels* result=NULL;
00114 
00115     if (!kernel)
00116     {
00117         SG_ERROR( "SVM can not proceed without kernel!\n");
00118         return false ;
00119     }
00120 
00121     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00122     {
00123         int32_t num_vectors=kernel->get_num_vec_rhs();
00124 
00125         result=new CLabels(num_vectors);
00126         SG_REF(result);
00127 
00128         ASSERT(num_vectors==result->get_num_labels());
00129         CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
00130 
00131         for (int32_t i=0; i<m_num_svms; i++)
00132         {
00133             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00134             ASSERT(m_svms[i]);
00135             m_svms[i]->set_kernel(kernel);
00136             outputs[i]=m_svms[i]->apply();
00137         }
00138 
00139         int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
00140         for (int32_t v=0; v<num_vectors; v++)
00141         {
00142             int32_t s=0;
00143             memset(votes, 0, sizeof(int32_t)*m_num_classes);
00144 
00145             for (int32_t i=0; i<m_num_classes; i++)
00146             {
00147                 for (int32_t j=i+1; j<m_num_classes; j++)
00148                 {
00149                     if (outputs[s++]->get_label(v)>0)
00150                         votes[i]++;
00151                     else
00152                         votes[j]++;
00153                 }
00154             }
00155 
00156             int32_t winner=0;
00157             int32_t max_votes=votes[0];
00158 
00159             for (int32_t i=1; i<m_num_classes; i++)
00160             {
00161                 if (votes[i]>max_votes)
00162                 {
00163                     max_votes=votes[i];
00164                     winner=i;
00165                 }
00166             }
00167 
00168             result->set_label(v, winner);
00169         }
00170 
00171         SG_FREE(votes);
00172 
00173         for (int32_t i=0; i<m_num_svms; i++)
00174             SG_UNREF(outputs[i]);
00175         SG_FREE(outputs);
00176     }
00177 
00178     return result;
00179 }
00180 
00181 CLabels* CMultiClassSVM::classify_one_vs_rest()
00182 {
00183     ASSERT(m_num_svms>0);
00184     CLabels* result=NULL;
00185 
00186     if (!kernel)
00187     {
00188         SG_ERROR( "SVM can not proceed without kernel!\n");
00189         return false ;
00190     }
00191 
00192     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00193     {
00194         int32_t num_vectors=kernel->get_num_vec_rhs();
00195 
00196         result=new CLabels(num_vectors);
00197         SG_REF(result);
00198 
00199         ASSERT(num_vectors==result->get_num_labels());
00200         CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
00201 
00202         for (int32_t i=0; i<m_num_svms; i++)
00203         {
00204             ASSERT(m_svms[i]);
00205             m_svms[i]->set_kernel(kernel);
00206             outputs[i]=m_svms[i]->apply();
00207         }
00208 
00209         for (int32_t i=0; i<num_vectors; i++)
00210         {
00211             int32_t winner=0;
00212             float64_t max_out=outputs[0]->get_label(i);
00213 
00214             for (int32_t j=1; j<m_num_svms; j++)
00215             {
00216                 float64_t out=outputs[j]->get_label(i);
00217 
00218                 if (out>max_out)
00219                 {
00220                     winner=j;
00221                     max_out=out;
00222                 }
00223             }
00224 
00225             result->set_label(i, winner);
00226         }
00227 
00228         for (int32_t i=0; i<m_num_svms; i++)
00229             SG_UNREF(outputs[i]);
00230 
00231         SG_FREE(outputs);
00232     }
00233 
00234     return result;
00235 }
00236 
00237 float64_t CMultiClassSVM::apply(int32_t num)
00238 {
00239     if (multiclass_type==ONE_VS_REST)
00240         return classify_example_one_vs_rest(num);
00241     else if (multiclass_type==ONE_VS_ONE)
00242         return classify_example_one_vs_one(num);
00243     else
00244         SG_ERROR("unknown multiclass type\n");
00245 
00246     return 0;
00247 }
00248 
00249 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00250 {
00251     ASSERT(m_num_svms>0);
00252     float64_t* outputs=SG_MALLOC(float64_t, m_num_svms);
00253     int32_t winner=0;
00254     float64_t max_out=m_svms[0]->apply(num);
00255 
00256     for (int32_t i=1; i<m_num_svms; i++)
00257     {
00258         outputs[i]=m_svms[i]->apply(num);
00259         if (outputs[i]>max_out)
00260         {
00261             winner=i;
00262             max_out=outputs[i];
00263         }
00264     }
00265     SG_FREE(outputs);
00266 
00267     return winner;
00268 }
00269 
00270 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00271 {
00272     ASSERT(m_num_svms>0);
00273     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00274 
00275     int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
00276     int32_t s=0;
00277 
00278     for (int32_t i=0; i<m_num_classes; i++)
00279     {
00280         for (int32_t j=i+1; j<m_num_classes; j++)
00281         {
00282             if (m_svms[s++]->apply(num)>0)
00283                 votes[i]++;
00284             else
00285                 votes[j]++;
00286         }
00287     }
00288 
00289     int32_t winner=0;
00290     int32_t max_votes=votes[0];
00291 
00292     for (int32_t i=1; i<m_num_classes; i++)
00293     {
00294         if (votes[i]>max_votes)
00295         {
00296             max_votes=votes[i];
00297             winner=i;
00298         }
00299     }
00300 
00301     SG_FREE(votes);
00302 
00303     return winner;
00304 }
00305 
00306 bool CMultiClassSVM::load(FILE* modelfl)
00307 {
00308     bool result=true;
00309     char char_buffer[1024];
00310     int32_t int_buffer;
00311     float64_t double_buffer;
00312     int32_t line_number=1;
00313     int32_t svm_idx=-1;
00314 
00315     SG_SET_LOCALE_C;
00316 
00317     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00318         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00319     else
00320     {
00321         char_buffer[15]='\0';
00322         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00323             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00324 
00325         line_number++;
00326     }
00327 
00328     int_buffer=0;
00329     if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00330         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00331 
00332     if (!feof(modelfl))
00333         line_number++;
00334 
00335     if (int_buffer != multiclass_type)
00336         SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00337 
00338     int_buffer=0;
00339     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00340         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00341 
00342     if (!feof(modelfl))
00343         line_number++;
00344 
00345     if (int_buffer < 2)
00346         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00347 
00348     create_multiclass_svm(int_buffer);
00349 
00350     int_buffer=0;
00351     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00352         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00353 
00354     if (!feof(modelfl))
00355         line_number++;
00356 
00357     if (m_num_svms != int_buffer)
00358         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00359 
00360     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00361         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00362 
00363     if (!feof(modelfl))
00364         line_number++;
00365 
00366     for (int32_t n=0; n<m_num_svms; n++)
00367     {
00368         svm_idx=-1;
00369         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00370         {
00371             result=false;
00372             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00373         }
00374         else
00375         {
00376             char_buffer[4]='\0';
00377             if (strncmp("%SVM", char_buffer, 4)!=0)
00378             {
00379                 result=false;
00380                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00381             }
00382 
00383             if (svm_idx != n)
00384                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00385 
00386             line_number++;
00387         }
00388 
00389         int_buffer=0;
00390         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00391             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00392 
00393         if (svm_idx != n)
00394             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00395 
00396         if (!feof(modelfl))
00397             line_number++;
00398 
00399         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00400         CSVM* svm=new CSVM(int_buffer);
00401 
00402         double_buffer=0;
00403 
00404         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00405             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00406 
00407         if (svm_idx != n)
00408             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00409 
00410         if (!feof(modelfl))
00411             line_number++;
00412 
00413         svm->set_bias(double_buffer);
00414 
00415         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00416             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00417 
00418         if (svm_idx != n)
00419             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00420 
00421         if (!feof(modelfl))
00422             line_number++;
00423 
00424         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00425         {
00426             double_buffer=0;
00427             int_buffer=0;
00428 
00429             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00430                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00431 
00432             if (!feof(modelfl))
00433                 line_number++;
00434 
00435             svm->set_support_vector(i, int_buffer);
00436             svm->set_alpha(i, double_buffer);
00437         }
00438 
00439         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00440         {
00441             result=false;
00442             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00443         }
00444         else
00445         {
00446             char_buffer[3]='\0';
00447             if (strcmp("];", char_buffer)!=0)
00448             {
00449                 result=false;
00450                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00451             }
00452             line_number++;
00453         }
00454 
00455         set_svm(n, svm);
00456     }
00457 
00458     svm_loaded=result;
00459 
00460     SG_RESET_LOCALE;
00461     return result;
00462 }
00463 
00464 bool CMultiClassSVM::save(FILE* modelfl)
00465 {
00466     SG_SET_LOCALE_C;
00467 
00468     if (!kernel)
00469         SG_ERROR("Kernel not defined!\n");
00470 
00471     if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00472         SG_ERROR("Multiclass SVM not trained!\n");
00473 
00474     SG_INFO( "Writing model file...");
00475     fprintf(modelfl,"%%MultiClassSVM\n");
00476     fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00477     fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00478     fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00479     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00480 
00481     for (int32_t i=0; i<m_num_svms; i++)
00482     {
00483         CSVM* svm=m_svms[i];
00484         ASSERT(svm);
00485         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00486         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00487         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00488 
00489         fprintf(modelfl, "alphas%d=[\n", i);
00490 
00491         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00492         {
00493             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00494                     svm->get_alpha(j), svm->get_support_vector(j));
00495         }
00496 
00497         fprintf(modelfl, "];\n");
00498     }
00499 
00500     SG_RESET_LOCALE;
00501     SG_DONE();
00502     return true ;
00503 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation