LibSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/LibSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
00018 : CSVM(), model(NULL), solver_type(st)
00019 {
00020 }
00021 
00022 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00023 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00024 {
00025     problem = svm_problem();
00026 }
00027 
00028 CLibSVM::~CLibSVM()
00029 {
00030 }
00031 
00032 
00033 bool CLibSVM::train_machine(CFeatures* data)
00034 {
00035     struct svm_node* x_space;
00036 
00037     ASSERT(m_labels && m_labels->get_num_labels());
00038     ASSERT(m_labels->get_label_type() == LT_BINARY);
00039 
00040     if (data)
00041     {
00042         if (m_labels->get_num_labels() != data->get_num_vectors())
00043         {
00044             SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
00045                     " not match number of labels (%d)\n", get_name(),
00046                     data->get_num_vectors(), m_labels->get_num_labels());
00047         }
00048         kernel->init(data, data);
00049     }
00050 
00051     problem.l=m_labels->get_num_labels();
00052     SG_INFO( "%d trainlabels\n", problem.l);
00053 
00054     // set linear term
00055     if (m_linear_term.vlen>0)
00056     {
00057         if (m_labels->get_num_labels()!=m_linear_term.vlen)
00058             SG_ERROR("Number of training vectors does not match length of linear term\n");
00059 
00060         // set with linear term from base class
00061         problem.pv = get_linear_term_array();
00062     }
00063     else
00064     {
00065         // fill with minus ones
00066         problem.pv = SG_MALLOC(float64_t, problem.l);
00067 
00068         for (int i=0; i!=problem.l; i++)
00069             problem.pv[i] = -1.0;
00070     }
00071 
00072     problem.y=SG_MALLOC(float64_t, problem.l);
00073     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00074     problem.C=SG_MALLOC(float64_t, problem.l);
00075 
00076     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00077 
00078     for (int32_t i=0; i<problem.l; i++)
00079     {
00080         problem.y[i]=((CBinaryLabels*) m_labels)->get_label(i);
00081         problem.x[i]=&x_space[2*i];
00082         x_space[2*i].index=i;
00083         x_space[2*i+1].index=-1;
00084     }
00085 
00086     int32_t weights_label[2]={-1,+1};
00087     float64_t weights[2]={1.0,get_C2()/get_C1()};
00088 
00089     ASSERT(kernel && kernel->has_features());
00090     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00091 
00092     param.svm_type=solver_type; // C SVM or NU_SVM
00093     param.kernel_type = LINEAR;
00094     param.degree = 3;
00095     param.gamma = 0;    // 1/k
00096     param.coef0 = 0;
00097     param.nu = get_nu();
00098     param.kernel=kernel;
00099     param.cache_size = kernel->get_cache_size();
00100     param.max_train_time = m_max_train_time;
00101     param.C = get_C1();
00102     param.eps = epsilon;
00103     param.p = 0.1;
00104     param.shrinking = 1;
00105     param.nr_weight = 2;
00106     param.weight_label = weights_label;
00107     param.weight = weights;
00108     param.use_bias = get_bias_enabled();
00109 
00110     const char* error_msg = svm_check_parameter(&problem, &param);
00111 
00112     if(error_msg)
00113         SG_ERROR("Error: %s\n",error_msg);
00114 
00115     model = svm_train(&problem, &param);
00116 
00117     if (model)
00118     {
00119         ASSERT(model->nr_class==2);
00120         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00121 
00122         int32_t num_sv=model->l;
00123 
00124         create_new_model(num_sv);
00125         CSVM::set_objective(model->objective);
00126 
00127         float64_t sgn=model->label[0];
00128 
00129         set_bias(-sgn*model->rho[0]);
00130 
00131         for (int32_t i=0; i<num_sv; i++)
00132         {
00133             set_support_vector(i, (model->SV[i])->index);
00134             set_alpha(i, sgn*model->sv_coef[0][i]);
00135         }
00136 
00137         SG_FREE(problem.x);
00138         SG_FREE(problem.y);
00139         SG_FREE(problem.pv);
00140         SG_FREE(problem.C);
00141 
00142 
00143         SG_FREE(x_space);
00144 
00145         svm_destroy_model(model);
00146         model=NULL;
00147         return true;
00148     }
00149     else
00150         return false;
00151 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation