MulticlassLibSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/multiclass/MulticlassLibSVM.h>
00012 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/io/SGIO.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassLibSVM::CMulticlassLibSVM(LIBSVM_SOLVER_TYPE st)
00019 : CMulticlassSVM(new CMulticlassOneVsOneStrategy()), model(NULL), solver_type(st)
00020 {
00021 }
00022 
00023 CMulticlassLibSVM::CMulticlassLibSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CMulticlassSVM(new CMulticlassOneVsOneStrategy(), C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00025 {
00026 }
00027 
00028 CMulticlassLibSVM::~CMulticlassLibSVM()
00029 {
00030 }
00031 
00032 bool CMulticlassLibSVM::train_machine(CFeatures* data)
00033 {
00034     struct svm_node* x_space;
00035 
00036     problem = svm_problem();
00037 
00038     ASSERT(m_labels && m_labels->get_num_labels());
00039     ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00040     int32_t num_classes = m_multiclass_strategy->get_num_classes();
00041     problem.l=m_labels->get_num_labels();
00042     SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00043 
00044 
00045     if (data)
00046     {
00047         if (m_labels->get_num_labels() != data->get_num_vectors())
00048         {
00049             SG_ERROR("Number of training vectors does not match number of "
00050                     "labels\n");
00051         }
00052         m_kernel->init(data, data);
00053     }
00054 
00055     problem.y=SG_MALLOC(float64_t, problem.l);
00056     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00057     problem.pv=SG_MALLOC(float64_t, problem.l);
00058     problem.C=SG_MALLOC(float64_t, problem.l);
00059 
00060     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00061 
00062     for (int32_t i=0; i<problem.l; i++)
00063     {
00064         problem.pv[i]=-1.0;
00065         problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
00066         problem.x[i]=&x_space[2*i];
00067         x_space[2*i].index=i;
00068         x_space[2*i+1].index=-1;
00069     }
00070 
00071     ASSERT(m_kernel);
00072 
00073     param.svm_type=solver_type; // C SVM or NU_SVM
00074     param.kernel_type = LINEAR;
00075     param.degree = 3;
00076     param.gamma = 0;    // 1/k
00077     param.coef0 = 0;
00078     param.nu = get_nu(); // Nu
00079     param.kernel=m_kernel;
00080     param.cache_size = m_kernel->get_cache_size();
00081     param.max_train_time = m_max_train_time;
00082     param.C = get_C();
00083     param.eps = get_epsilon();
00084     param.p = 0.1;
00085     param.shrinking = 1;
00086     param.nr_weight = 0;
00087     param.weight_label = NULL;
00088     param.weight = NULL;
00089     param.use_bias = svm_proto()->get_bias_enabled();
00090 
00091     const char* error_msg = svm_check_parameter(&problem,&param);
00092 
00093     if(error_msg)
00094         SG_ERROR("Error: %s\n",error_msg);
00095 
00096     model = svm_train(&problem, &param);
00097 
00098     if (model)
00099     {
00100         if (model->nr_class!=num_classes)
00101         {
00102             SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
00103                     model->nr_class, num_classes);
00104         }
00105         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00106         create_multiclass_svm(num_classes);
00107 
00108         int32_t* offsets=SG_MALLOC(int32_t, num_classes);
00109         offsets[0]=0;
00110 
00111         for (int32_t i=1; i<num_classes; i++)
00112             offsets[i] = offsets[i-1]+model->nSV[i-1];
00113 
00114         int32_t s=0;
00115         for (int32_t i=0; i<num_classes; i++)
00116         {
00117             for (int32_t j=i+1; j<num_classes; j++)
00118             {
00119                 int32_t k, l;
00120 
00121                 float64_t sgn=1;
00122                 if (model->label[i]>model->label[j])
00123                     sgn=-1;
00124 
00125                 int32_t num_sv=model->nSV[i]+model->nSV[j];
00126                 float64_t bias=-model->rho[s];
00127 
00128                 ASSERT(num_sv>0);
00129                 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00130 
00131                 CSVM* svm=new CSVM(num_sv);
00132 
00133                 svm->set_bias(sgn*bias);
00134 
00135                 int32_t sv_idx=0;
00136                 for (k=0; k<model->nSV[i]; k++)
00137                 {
00138                     SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00139                             model->SV[offsets[i]+k]->index);
00140                     svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00141                     svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00142                     sv_idx++;
00143                 }
00144 
00145                 for (k=0; k<model->nSV[j]; k++)
00146                 {
00147                     SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00148                             model->SV[offsets[i]+k]->index);
00149                     svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00150                     svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00151                     sv_idx++;
00152                 }
00153 
00154                 int32_t idx=0;
00155 
00156                 if (num_classes > 3)
00157                 {
00158                     if (sgn>0)
00159                     {
00160                         for (k=0; k<model->label[i]; k++)
00161                             idx+=num_classes-k-1;
00162     
00163                         for (l=model->label[i]+1; l<model->label[j]; l++)
00164                             idx++;
00165                     }
00166                     else
00167                     {
00168                         for (k=0; k<model->label[j]; k++)
00169                             idx+=num_classes-k-1;
00170 
00171                         for (l=model->label[j]+1; l<model->label[i]; l++)
00172                             idx++;
00173                     }
00174                 }
00175                 else if (num_classes == 3)
00176                 {
00177                     idx = model->label[j]+model->label[i] - 3;
00178                 }
00179                 else if (num_classes == 2)
00180                 {
00181                     idx = i;
00182                 }
00183 //
00184 //              if (sgn>0)
00185 //                  idx=((num_classes-1)*model->label[i]+model->label[j])/2;
00186 //              else
00187 //                  idx=((num_classes-1)*model->label[j]+model->label[i])/2;
00188 //
00189                 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f "
00190                         "label:(%d,%d) -> svm[%d]\n",
00191                         s, num_sv, model->l, bias, model->label[i],
00192                         model->label[j], idx);
00193 
00194                 REQUIRE(set_svm(idx, svm),"SVM set failed");
00195                 s++;
00196             }
00197         }
00198 
00199         set_objective(model->objective);
00200 
00201         SG_FREE(offsets);
00202         SG_FREE(problem.x);
00203         SG_FREE(problem.y);
00204         SG_FREE(x_space);
00205         SG_FREE(problem.pv);
00206         SG_FREE(problem.C);
00207 
00208         svm_destroy_model(model);
00209         model=NULL;
00210 
00211         return true;
00212     }
00213     else
00214         return false;
00215 }
00216 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation