LibSVR.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/regression/svr/LibSVR.h>
00012 #include <shogun/labels/RegressionLabels.h>
00013 #include <shogun/io/SGIO.h>
00014 
00015 using namespace shogun;
00016 
00017 CLibSVR::CLibSVR()
00018 : CSVM()
00019 {
00020     model=NULL;
00021 }
00022 
00023 CLibSVR::CLibSVR(float64_t C, float64_t eps, CKernel* k, CLabels* lab)
00024 : CSVM()
00025 {
00026     model=NULL;
00027 
00028     set_C(C,C);
00029     set_tube_epsilon(eps);
00030     set_labels(lab);
00031     set_kernel(k);
00032 }
00033 
00034 CLibSVR::~CLibSVR()
00035 {
00036     SG_FREE(model);
00037 }
00038 
00039 EMachineType CLibSVR::get_classifier_type()
00040 {
00041     return CT_LIBSVR;
00042 }
00043 
00044 bool CLibSVR::train_machine(CFeatures* data)
00045 {
00046     ASSERT(kernel);
00047     ASSERT(m_labels && m_labels->get_num_labels());
00048     ASSERT(m_labels->get_label_type() == LT_REGRESSION);
00049 
00050     if (data)
00051     {
00052         if (m_labels->get_num_labels() != data->get_num_vectors())
00053             SG_ERROR("Number of training vectors does not match number of labels\n");
00054         kernel->init(data, data);
00055     }
00056 
00057     SG_FREE(model);
00058 
00059     struct svm_node* x_space;
00060 
00061     problem.l=m_labels->get_num_labels();
00062     SG_INFO( "%d trainlabels\n", problem.l);
00063 
00064     problem.y=SG_MALLOC(float64_t, problem.l);
00065     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00066     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00067 
00068     for (int32_t i=0; i<problem.l; i++)
00069     {
00070         problem.y[i]=((CRegressionLabels*) m_labels)->get_label(i);
00071         problem.x[i]=&x_space[2*i];
00072         x_space[2*i].index=i;
00073         x_space[2*i+1].index=-1;
00074     }
00075 
00076     int32_t weights_label[2]={-1,+1};
00077     float64_t weights[2]={1.0,get_C2()/get_C1()};
00078 
00079     param.svm_type=EPSILON_SVR; // epsilon SVR
00080     param.kernel_type = LINEAR;
00081     param.degree = 3;
00082     param.gamma = 0;    // 1/k
00083     param.coef0 = 0;
00084     param.nu = 0.5;
00085     param.kernel=kernel;
00086     param.cache_size = kernel->get_cache_size();
00087     param.max_train_time = m_max_train_time;
00088     param.C = get_C1();
00089     param.eps = epsilon;
00090     param.p = tube_epsilon;
00091     param.shrinking = 1;
00092     param.nr_weight = 2;
00093     param.weight_label = weights_label;
00094     param.weight = weights;
00095     param.use_bias = get_bias_enabled();
00096 
00097     const char* error_msg = svm_check_parameter(&problem,&param);
00098 
00099     if(error_msg)
00100         SG_ERROR("Error: %s\n",error_msg);
00101 
00102     model = svm_train(&problem, &param);
00103 
00104     if (model)
00105     {
00106         ASSERT(model->nr_class==2);
00107         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00108 
00109         int32_t num_sv=model->l;
00110 
00111         create_new_model(num_sv);
00112 
00113         CSVM::set_objective(model->objective);
00114 
00115         set_bias(-model->rho[0]);
00116 
00117         for (int32_t i=0; i<num_sv; i++)
00118         {
00119             set_support_vector(i, (model->SV[i])->index);
00120             set_alpha(i, model->sv_coef[0][i]);
00121         }
00122 
00123         SG_FREE(problem.x);
00124         SG_FREE(problem.y);
00125         SG_FREE(x_space);
00126 
00127         svm_destroy_model(model);
00128         model=NULL;
00129         return true;
00130     }
00131     else
00132         return false;
00133 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation