/* (Doxygen navigation text and residual line numbers removed during extraction;
 * this is the Shogun toolbox LibSVM wrapper implementation.) */
00011 #include <shogun/classifier/svm/LibSVM.h>
00012 #include <shogun/io/SGIO.h>
00013
00014 using namespace shogun;
00015
00016 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
00017 : CSVM(), model(NULL), solver_type(st)
00018 {
00019 }
00020
00021 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00023 {
00024 problem = svm_problem();
00025 }
00026
00027 CLibSVM::~CLibSVM()
00028 {
00029 }
00030
00031
00032 bool CLibSVM::train_machine(CFeatures* data)
00033 {
00034 struct svm_node* x_space;
00035
00036 ASSERT(labels && labels->get_num_labels());
00037 ASSERT(labels->is_two_class_labeling());
00038
00039 if (data)
00040 {
00041 if (labels->get_num_labels() != data->get_num_vectors())
00042 SG_ERROR("Number of training vectors does not match number of labels\n");
00043 kernel->init(data, data);
00044 }
00045
00046 problem.l=labels->get_num_labels();
00047 SG_INFO( "%d trainlabels\n", problem.l);
00048
00049
00050 if (m_linear_term.vlen>0)
00051 {
00052 if (labels->get_num_labels()!=m_linear_term.vlen)
00053 SG_ERROR("Number of training vectors does not match length of linear term\n");
00054
00055
00056 problem.pv = get_linear_term_array();
00057 }
00058 else
00059 {
00060
00061 problem.pv = SG_MALLOC(float64_t, problem.l);
00062
00063 for (int i=0; i!=problem.l; i++)
00064 problem.pv[i] = -1.0;
00065 }
00066
00067 problem.y=SG_MALLOC(float64_t, problem.l);
00068 problem.x=SG_MALLOC(struct svm_node*, problem.l);
00069 problem.C=SG_MALLOC(float64_t, problem.l);
00070
00071 x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00072
00073 for (int32_t i=0; i<problem.l; i++)
00074 {
00075 problem.y[i]=labels->get_label(i);
00076 problem.x[i]=&x_space[2*i];
00077 x_space[2*i].index=i;
00078 x_space[2*i+1].index=-1;
00079 }
00080
00081 int32_t weights_label[2]={-1,+1};
00082 float64_t weights[2]={1.0,get_C2()/get_C1()};
00083
00084 ASSERT(kernel && kernel->has_features());
00085 ASSERT(kernel->get_num_vec_lhs()==problem.l);
00086
00087 param.svm_type=solver_type;
00088 param.kernel_type = LINEAR;
00089 param.degree = 3;
00090 param.gamma = 0;
00091 param.coef0 = 0;
00092 param.nu = get_nu();
00093 param.kernel=kernel;
00094 param.cache_size = kernel->get_cache_size();
00095 param.max_train_time = max_train_time;
00096 param.C = get_C1();
00097 param.eps = epsilon;
00098 param.p = 0.1;
00099 param.shrinking = 1;
00100 param.nr_weight = 2;
00101 param.weight_label = weights_label;
00102 param.weight = weights;
00103 param.use_bias = get_bias_enabled();
00104
00105 const char* error_msg = svm_check_parameter(&problem, ¶m);
00106
00107 if(error_msg)
00108 SG_ERROR("Error: %s\n",error_msg);
00109
00110 model = svm_train(&problem, ¶m);
00111
00112 if (model)
00113 {
00114 ASSERT(model->nr_class==2);
00115 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00116
00117 int32_t num_sv=model->l;
00118
00119 create_new_model(num_sv);
00120 CSVM::set_objective(model->objective);
00121
00122 float64_t sgn=model->label[0];
00123
00124 set_bias(-sgn*model->rho[0]);
00125
00126 for (int32_t i=0; i<num_sv; i++)
00127 {
00128 set_support_vector(i, (model->SV[i])->index);
00129 set_alpha(i, sgn*model->sv_coef[0][i]);
00130 }
00131
00132 SG_FREE(problem.x);
00133 SG_FREE(problem.y);
00134 SG_FREE(problem.pv);
00135 SG_FREE(problem.C);
00136
00137
00138 SG_FREE(x_space);
00139
00140 svm_destroy_model(model);
00141 model=NULL;
00142 return true;
00143 }
00144 else
00145 return false;
00146 }