LibSVMMultiClass.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/LibSVMMultiClass.h>
00012 #include <shogun/io/SGIO.h>
00013 
00014 using namespace shogun;
00015 
00016 CLibSVMMultiClass::CLibSVMMultiClass(LIBSVM_SOLVER_TYPE st)
00017 : CMultiClassSVM(ONE_VS_ONE), model(NULL), solver_type(st)
00018 {
00019 }
00020 
00021 CLibSVMMultiClass::CLibSVMMultiClass(float64_t C, CKernel* k, CLabels* lab)
00022 : CMultiClassSVM(ONE_VS_ONE, C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00023 {
00024 }
00025 
00026 CLibSVMMultiClass::~CLibSVMMultiClass()
00027 {
00028     //SG_PRINT("deleting LibSVM\n");
00029 }
00030 
00031 bool CLibSVMMultiClass::train_machine(CFeatures* data)
00032 {
00033     struct svm_node* x_space;
00034 
00035     problem = svm_problem();
00036 
00037     ASSERT(labels && labels->get_num_labels());
00038     int32_t num_classes = labels->get_num_classes();
00039     problem.l=labels->get_num_labels();
00040     SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00041 
00042     if (data)
00043     {
00044         if (labels->get_num_labels() != data->get_num_vectors())
00045             SG_ERROR("Number of training vectors does not match number of labels\n");
00046         kernel->init(data, data);
00047     }
00048 
00049     problem.y=SG_MALLOC(float64_t, problem.l);
00050     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00051     problem.pv=SG_MALLOC(float64_t, problem.l);
00052     problem.C=SG_MALLOC(float64_t, problem.l);
00053 
00054     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00055 
00056     for (int32_t i=0; i<problem.l; i++)
00057     {
00058         problem.pv[i]=-1.0;
00059         problem.y[i]=labels->get_label(i);
00060         problem.x[i]=&x_space[2*i];
00061         x_space[2*i].index=i;
00062         x_space[2*i+1].index=-1;
00063     }
00064 
00065     ASSERT(kernel);
00066 
00067     param.svm_type=solver_type; // C SVM or NU_SVM
00068     param.kernel_type = LINEAR;
00069     param.degree = 3;
00070     param.gamma = 0;    // 1/k
00071     param.coef0 = 0;
00072     param.nu = get_nu(); // Nu
00073     param.kernel=kernel;
00074     param.cache_size = kernel->get_cache_size();
00075     param.max_train_time = max_train_time;
00076     param.C = get_C1();
00077     param.eps = epsilon;
00078     param.p = 0.1;
00079     param.shrinking = 1;
00080     param.nr_weight = 0;
00081     param.weight_label = NULL;
00082     param.weight = NULL;
00083     param.use_bias = get_bias_enabled();
00084 
00085     const char* error_msg = svm_check_parameter(&problem,&param);
00086 
00087     if(error_msg)
00088         SG_ERROR("Error: %s\n",error_msg);
00089 
00090     model = svm_train(&problem, &param);
00091 
00092     if (model)
00093     {
00094         if (model->nr_class!=num_classes)
00095         {
00096             SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
00097                     model->nr_class, num_classes);
00098         }
00099         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00100         create_multiclass_svm(num_classes);
00101 
00102         int32_t* offsets=SG_MALLOC(int32_t, num_classes);
00103         offsets[0]=0;
00104 
00105         for (int32_t i=1; i<num_classes; i++)
00106             offsets[i] = offsets[i-1]+model->nSV[i-1];
00107 
00108         int32_t s=0;
00109         for (int32_t i=0; i<num_classes; i++)
00110         {
00111             for (int32_t j=i+1; j<num_classes; j++)
00112             {
00113                 int32_t k, l;
00114 
00115                 float64_t sgn=1;
00116                 if (model->label[i]>model->label[j])
00117                     sgn=-1;
00118 
00119                 int32_t num_sv=model->nSV[i]+model->nSV[j];
00120                 float64_t bias=-model->rho[s];
00121 
00122                 ASSERT(num_sv>0);
00123                 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00124 
00125                 CSVM* svm=new CSVM(num_sv);
00126 
00127                 svm->set_bias(sgn*bias);
00128 
00129                 int32_t sv_idx=0;
00130                 for (k=0; k<model->nSV[i]; k++)
00131                 {
00132                     svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00133                     svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00134                     sv_idx++;
00135                 }
00136 
00137                 for (k=0; k<model->nSV[j]; k++)
00138                 {
00139                     svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00140                     svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00141                     sv_idx++;
00142                 }
00143 
00144                 int32_t idx=0;
00145 
00146                 if (sgn>0)
00147                 {
00148                     for (k=0; k<model->label[i]; k++)
00149                         idx+=num_classes-k-1;
00150 
00151                     for (l=model->label[i]+1; l<model->label[j]; l++)
00152                         idx++;
00153                 }
00154                 else
00155                 {
00156                     for (k=0; k<model->label[j]; k++)
00157                         idx+=num_classes-k-1;
00158 
00159                     for (l=model->label[j]+1; l<model->label[i]; l++)
00160                         idx++;
00161                 }
00162 
00163 
00164 //              if (sgn>0)
00165 //                  idx=((num_classes-1)*model->label[i]+model->label[j])/2;
00166 //              else
00167 //                  idx=((num_classes-1)*model->label[j]+model->label[i])/2;
00168 //
00169                 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f label:(%d,%d) -> svm[%d]\n", s, num_sv, model->l, bias, model->label[i], model->label[j], idx);
00170 
00171                 set_svm(idx, svm);
00172                 s++;
00173             }
00174         }
00175 
00176         CSVM::set_objective(model->objective);
00177 
00178         SG_FREE(offsets);
00179         SG_FREE(problem.x);
00180         SG_FREE(problem.y);
00181         SG_FREE(x_space);
00182 
00183         svm_destroy_model(model);
00184         model=NULL;
00185 
00186         return true;
00187     }
00188     else
00189         return false;
00190 }
00191 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation