Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/MulticlassLibSVM.h>
00012 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/io/SGIO.h>
00015
00016 using namespace shogun;
00017
00018 CMulticlassLibSVM::CMulticlassLibSVM(LIBSVM_SOLVER_TYPE st)
00019 : CMulticlassSVM(new CMulticlassOneVsOneStrategy()), model(NULL), solver_type(st)
00020 {
00021 }
00022
00023 CMulticlassLibSVM::CMulticlassLibSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CMulticlassSVM(new CMulticlassOneVsOneStrategy(), C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00025 {
00026 }
00027
00028 CMulticlassLibSVM::~CMulticlassLibSVM()
00029 {
00030 }
00031
00032 bool CMulticlassLibSVM::train_machine(CFeatures* data)
00033 {
00034 struct svm_node* x_space;
00035
00036 problem = svm_problem();
00037
00038 ASSERT(m_labels && m_labels->get_num_labels());
00039 ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00040 int32_t num_classes = m_multiclass_strategy->get_num_classes();
00041 problem.l=m_labels->get_num_labels();
00042 SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00043
00044
00045 if (data)
00046 {
00047 if (m_labels->get_num_labels() != data->get_num_vectors())
00048 {
00049 SG_ERROR("Number of training vectors does not match number of "
00050 "labels\n");
00051 }
00052 m_kernel->init(data, data);
00053 }
00054
00055 problem.y=SG_MALLOC(float64_t, problem.l);
00056 problem.x=SG_MALLOC(struct svm_node*, problem.l);
00057 problem.pv=SG_MALLOC(float64_t, problem.l);
00058 problem.C=SG_MALLOC(float64_t, problem.l);
00059
00060 x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00061
00062 for (int32_t i=0; i<problem.l; i++)
00063 {
00064 problem.pv[i]=-1.0;
00065 problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
00066 problem.x[i]=&x_space[2*i];
00067 x_space[2*i].index=i;
00068 x_space[2*i+1].index=-1;
00069 }
00070
00071 ASSERT(m_kernel);
00072
00073 param.svm_type=solver_type;
00074 param.kernel_type = LINEAR;
00075 param.degree = 3;
00076 param.gamma = 0;
00077 param.coef0 = 0;
00078 param.nu = get_nu();
00079 param.kernel=m_kernel;
00080 param.cache_size = m_kernel->get_cache_size();
00081 param.max_train_time = m_max_train_time;
00082 param.C = get_C();
00083 param.eps = get_epsilon();
00084 param.p = 0.1;
00085 param.shrinking = 1;
00086 param.nr_weight = 0;
00087 param.weight_label = NULL;
00088 param.weight = NULL;
00089 param.use_bias = svm_proto()->get_bias_enabled();
00090
00091 const char* error_msg = svm_check_parameter(&problem,¶m);
00092
00093 if(error_msg)
00094 SG_ERROR("Error: %s\n",error_msg);
00095
00096 model = svm_train(&problem, ¶m);
00097
00098 if (model)
00099 {
00100 if (model->nr_class!=num_classes)
00101 {
00102 SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
00103 model->nr_class, num_classes);
00104 }
00105 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00106 create_multiclass_svm(num_classes);
00107
00108 int32_t* offsets=SG_MALLOC(int32_t, num_classes);
00109 offsets[0]=0;
00110
00111 for (int32_t i=1; i<num_classes; i++)
00112 offsets[i] = offsets[i-1]+model->nSV[i-1];
00113
00114 int32_t s=0;
00115 for (int32_t i=0; i<num_classes; i++)
00116 {
00117 for (int32_t j=i+1; j<num_classes; j++)
00118 {
00119 int32_t k, l;
00120
00121 float64_t sgn=1;
00122 if (model->label[i]>model->label[j])
00123 sgn=-1;
00124
00125 int32_t num_sv=model->nSV[i]+model->nSV[j];
00126 float64_t bias=-model->rho[s];
00127
00128 ASSERT(num_sv>0);
00129 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00130
00131 CSVM* svm=new CSVM(num_sv);
00132
00133 svm->set_bias(sgn*bias);
00134
00135 int32_t sv_idx=0;
00136 for (k=0; k<model->nSV[i]; k++)
00137 {
00138 SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00139 model->SV[offsets[i]+k]->index);
00140 svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00141 svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00142 sv_idx++;
00143 }
00144
00145 for (k=0; k<model->nSV[j]; k++)
00146 {
00147 SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00148 model->SV[offsets[i]+k]->index);
00149 svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00150 svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00151 sv_idx++;
00152 }
00153
00154 int32_t idx=0;
00155
00156 if (num_classes > 3)
00157 {
00158 if (sgn>0)
00159 {
00160 for (k=0; k<model->label[i]; k++)
00161 idx+=num_classes-k-1;
00162
00163 for (l=model->label[i]+1; l<model->label[j]; l++)
00164 idx++;
00165 }
00166 else
00167 {
00168 for (k=0; k<model->label[j]; k++)
00169 idx+=num_classes-k-1;
00170
00171 for (l=model->label[j]+1; l<model->label[i]; l++)
00172 idx++;
00173 }
00174 }
00175 else if (num_classes == 3)
00176 {
00177 idx = model->label[j]+model->label[i] - 3;
00178 }
00179 else if (num_classes == 2)
00180 {
00181 idx = i;
00182 }
00183
00184
00185
00186
00187
00188
00189 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f "
00190 "label:(%d,%d) -> svm[%d]\n",
00191 s, num_sv, model->l, bias, model->label[i],
00192 model->label[j], idx);
00193
00194 REQUIRE(set_svm(idx, svm),"SVM set failed");
00195 s++;
00196 }
00197 }
00198
00199 set_objective(model->objective);
00200
00201 SG_FREE(offsets);
00202 SG_FREE(problem.x);
00203 SG_FREE(problem.y);
00204 SG_FREE(x_space);
00205 SG_FREE(problem.pv);
00206 SG_FREE(problem.C);
00207
00208 svm_destroy_model(model);
00209 model=NULL;
00210
00211 return true;
00212 }
00213 else
00214 return false;
00215 }
00216