MulticlassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/multiclass/MulticlassSVM.h>
00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassSVM::CMulticlassSVM()
00019     :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL), m_C(0)
00020 {
00021     init();
00022 }
00023 
00024 CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy)
00025     :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL), m_C(0)
00026 {
00027     init();
00028 }
00029 
00030 CMulticlassSVM::CMulticlassSVM(
00031     CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
00032     : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab), m_C(C)
00033 {
00034     init();
00035 }
00036 
00037 CMulticlassSVM::~CMulticlassSVM()
00038 {
00039 }
00040 
00041 void CMulticlassSVM::init()
00042 {
00043     SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
00044 }
00045 
00046 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
00047 {
00048     if (num_classes>0)
00049     {
00050         int32_t num_svms=m_multiclass_strategy->get_num_machines();
00051 
00052         m_machines->reset_array();
00053         for (index_t i=0; i<num_svms; ++i)
00054             m_machines->push_back(NULL);
00055 
00056         return true;
00057     }
00058     return false;
00059 }
00060 
00061 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm)
00062 {
00063     if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm)
00064     {
00065         m_machines->set_element(svm, num);
00066         return true;
00067     }
00068     return false;
00069 }
00070 
00071 bool CMulticlassSVM::init_machines_for_apply(CFeatures* data)
00072 {
00073     if (is_data_locked())
00074     {
00075         SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when "
00076                 "data_lock was called before. Call data_unlock to allow.");
00077     }
00078 
00079     if (!m_kernel)
00080         SG_ERROR("No kernel assigned!\n");
00081 
00082     CFeatures* lhs=m_kernel->get_lhs();
00083     if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED)
00084         SG_ERROR("%s: No left hand side specified\n", get_name());
00085 
00086     if (m_kernel->get_kernel_type()!=K_COMBINED && !lhs->get_num_vectors())
00087     {
00088         SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to"
00089                 " an implementation error in %s, where it was forgotten to set "
00090                 "the data (m_svs) indices\n", get_name(),
00091                 data->get_name());
00092     }
00093 
00094     if (data && m_kernel->get_kernel_type()!=K_COMBINED)
00095         m_kernel->init(lhs, data);
00096     SG_UNREF(lhs);
00097 
00098     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00099     {
00100         CSVM *the_svm = (CSVM *)m_machines->get_element(i);
00101         ASSERT(the_svm);
00102         the_svm->set_kernel(m_kernel);
00103         SG_UNREF(the_svm);
00104     }
00105 
00106     return true;
00107 }
00108 
00109 bool CMulticlassSVM::load(FILE* modelfl)
00110 {
00111     bool result=true;
00112     char char_buffer[1024];
00113     int32_t int_buffer;
00114     float64_t double_buffer;
00115     int32_t line_number=1;
00116     int32_t svm_idx=-1;
00117 
00118     SG_SET_LOCALE_C;
00119 
00120     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00121         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00122     else
00123     {
00124         char_buffer[15]='\0';
00125         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00126             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00127 
00128         line_number++;
00129     }
00130 
00131     int_buffer=0;
00132     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00133         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00134 
00135     if (!feof(modelfl))
00136         line_number++;
00137 
00138     if (int_buffer < 2)
00139         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00140 
00141     create_multiclass_svm(int_buffer);
00142 
00143     int_buffer=0;
00144     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00145         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00146 
00147     if (!feof(modelfl))
00148         line_number++;
00149 
00150     if (m_machines->get_num_elements() != int_buffer)
00151         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer);
00152 
00153     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00154         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00155 
00156     if (!feof(modelfl))
00157         line_number++;
00158 
00159     for (int32_t n=0; n<m_machines->get_num_elements(); n++)
00160     {
00161         svm_idx=-1;
00162         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00163         {
00164             result=false;
00165             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00166         }
00167         else
00168         {
00169             char_buffer[4]='\0';
00170             if (strncmp("%SVM", char_buffer, 4)!=0)
00171             {
00172                 result=false;
00173                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00174             }
00175 
00176             if (svm_idx != n)
00177                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00178 
00179             line_number++;
00180         }
00181 
00182         int_buffer=0;
00183         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00184             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00185 
00186         if (svm_idx != n)
00187             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00188 
00189         if (!feof(modelfl))
00190             line_number++;
00191 
00192         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00193         CSVM* svm=new CSVM(int_buffer);
00194 
00195         double_buffer=0;
00196 
00197         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00198             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00199 
00200         if (svm_idx != n)
00201             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00202 
00203         if (!feof(modelfl))
00204             line_number++;
00205 
00206         svm->set_bias(double_buffer);
00207 
00208         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00209             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00210 
00211         if (svm_idx != n)
00212             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00213 
00214         if (!feof(modelfl))
00215             line_number++;
00216 
00217         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00218         {
00219             double_buffer=0;
00220             int_buffer=0;
00221 
00222             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00223                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00224 
00225             if (!feof(modelfl))
00226                 line_number++;
00227 
00228             svm->set_support_vector(i, int_buffer);
00229             svm->set_alpha(i, double_buffer);
00230         }
00231 
00232         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00233         {
00234             result=false;
00235             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00236         }
00237         else
00238         {
00239             char_buffer[3]='\0';
00240             if (strcmp("];", char_buffer)!=0)
00241             {
00242                 result=false;
00243                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00244             }
00245             line_number++;
00246         }
00247 
00248         set_svm(n, svm);
00249     }
00250 
00251     svm_proto()->svm_loaded=result;
00252 
00253     SG_RESET_LOCALE;
00254     return result;
00255 }
00256 
00257 bool CMulticlassSVM::save(FILE* modelfl)
00258 {
00259     SG_SET_LOCALE_C;
00260 
00261     if (!m_kernel)
00262         SG_ERROR("Kernel not defined!\n");
00263 
00264     if (m_machines->get_num_elements()<1)
00265         SG_ERROR("Multiclass SVM not trained!\n");
00266 
00267     SG_INFO( "Writing model file...");
00268     fprintf(modelfl,"%%MultiClassSVM\n");
00269     fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
00270     fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
00271     fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
00272 
00273     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00274     {
00275         CSVM* svm=get_svm(i);
00276         ASSERT(svm);
00277         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1);
00278         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00279         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00280 
00281         fprintf(modelfl, "alphas%d=[\n", i);
00282 
00283         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00284         {
00285             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00286                     svm->get_alpha(j), svm->get_support_vector(j));
00287         }
00288 
00289         fprintf(modelfl, "];\n");
00290     }
00291 
00292     SG_RESET_LOCALE;
00293     SG_DONE();
00294     return true ;
00295 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation