LibSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/LibSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 
00014 using namespace shogun;
00015 
00016 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
00017 : CSVM(), model(NULL), solver_type(st)
00018 {
00019 }
00020 
00021 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00023 {
00024     problem = svm_problem();
00025 }
00026 
00027 CLibSVM::~CLibSVM()
00028 {
00029 }
00030 
00031 
00032 bool CLibSVM::train_machine(CFeatures* data)
00033 {
00034     struct svm_node* x_space;
00035 
00036     ASSERT(labels && labels->get_num_labels());
00037     ASSERT(labels->is_two_class_labeling());
00038 
00039     if (data)
00040     {
00041         if (labels->get_num_labels() != data->get_num_vectors())
00042             SG_ERROR("Number of training vectors does not match number of labels\n");
00043         kernel->init(data, data);
00044     }
00045 
00046     problem.l=labels->get_num_labels();
00047     SG_INFO( "%d trainlabels\n", problem.l);
00048 
00049     // set linear term
00050     if (m_linear_term.vlen>0)
00051     {
00052         if (labels->get_num_labels()!=m_linear_term.vlen)
00053             SG_ERROR("Number of training vectors does not match length of linear term\n");
00054 
00055         // set with linear term from base class
00056         problem.pv = get_linear_term_array();
00057     }
00058     else
00059     {
00060         // fill with minus ones
00061         problem.pv = SG_MALLOC(float64_t, problem.l);
00062 
00063         for (int i=0; i!=problem.l; i++)
00064             problem.pv[i] = -1.0;
00065     }
00066 
00067     problem.y=SG_MALLOC(float64_t, problem.l);
00068     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00069     problem.C=SG_MALLOC(float64_t, problem.l);
00070 
00071     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00072 
00073     for (int32_t i=0; i<problem.l; i++)
00074     {
00075         problem.y[i]=labels->get_label(i);
00076         problem.x[i]=&x_space[2*i];
00077         x_space[2*i].index=i;
00078         x_space[2*i+1].index=-1;
00079     }
00080 
00081     int32_t weights_label[2]={-1,+1};
00082     float64_t weights[2]={1.0,get_C2()/get_C1()};
00083 
00084     ASSERT(kernel && kernel->has_features());
00085     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00086 
00087     param.svm_type=solver_type; // C SVM or NU_SVM
00088     param.kernel_type = LINEAR;
00089     param.degree = 3;
00090     param.gamma = 0;    // 1/k
00091     param.coef0 = 0;
00092     param.nu = get_nu();
00093     param.kernel=kernel;
00094     param.cache_size = kernel->get_cache_size();
00095     param.max_train_time = max_train_time;
00096     param.C = get_C1();
00097     param.eps = epsilon;
00098     param.p = 0.1;
00099     param.shrinking = 1;
00100     param.nr_weight = 2;
00101     param.weight_label = weights_label;
00102     param.weight = weights;
00103     param.use_bias = get_bias_enabled();
00104 
00105     const char* error_msg = svm_check_parameter(&problem, &param);
00106 
00107     if(error_msg)
00108         SG_ERROR("Error: %s\n",error_msg);
00109 
00110     model = svm_train(&problem, &param);
00111 
00112     if (model)
00113     {
00114         ASSERT(model->nr_class==2);
00115         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00116 
00117         int32_t num_sv=model->l;
00118 
00119         create_new_model(num_sv);
00120         CSVM::set_objective(model->objective);
00121 
00122         float64_t sgn=model->label[0];
00123 
00124         set_bias(-sgn*model->rho[0]);
00125 
00126         for (int32_t i=0; i<num_sv; i++)
00127         {
00128             set_support_vector(i, (model->SV[i])->index);
00129             set_alpha(i, sgn*model->sv_coef[0][i]);
00130         }
00131 
00132         SG_FREE(problem.x);
00133         SG_FREE(problem.y);
00134         SG_FREE(problem.pv);
00135         SG_FREE(problem.C);
00136 
00137 
00138         SG_FREE(x_space);
00139 
00140         svm_destroy_model(model);
00141         model=NULL;
00142         return true;
00143     }
00144     else
00145         return false;
00146 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation