Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/classifier/svm/LibSVMMultiClass.h>
00012 #include <shogun/io/SGIO.h>
00013
00014 using namespace shogun;
00015
00016 CLibSVMMultiClass::CLibSVMMultiClass(LIBSVM_SOLVER_TYPE st)
00017 : CMultiClassSVM(ONE_VS_ONE), model(NULL), solver_type(st)
00018 {
00019 }
00020
00021 CLibSVMMultiClass::CLibSVMMultiClass(float64_t C, CKernel* k, CLabels* lab)
00022 : CMultiClassSVM(ONE_VS_ONE, C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00023 {
00024 }
00025
00026 CLibSVMMultiClass::~CLibSVMMultiClass()
00027 {
00028
00029 }
00030
00031 bool CLibSVMMultiClass::train_machine(CFeatures* data)
00032 {
00033 struct svm_node* x_space;
00034
00035 problem = svm_problem();
00036
00037 ASSERT(labels && labels->get_num_labels());
00038 int32_t num_classes = labels->get_num_classes();
00039 problem.l=labels->get_num_labels();
00040 SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00041
00042 if (data)
00043 {
00044 if (labels->get_num_labels() != data->get_num_vectors())
00045 SG_ERROR("Number of training vectors does not match number of labels\n");
00046 kernel->init(data, data);
00047 }
00048
00049 problem.y=SG_MALLOC(float64_t, problem.l);
00050 problem.x=SG_MALLOC(struct svm_node*, problem.l);
00051 problem.pv=SG_MALLOC(float64_t, problem.l);
00052 problem.C=SG_MALLOC(float64_t, problem.l);
00053
00054 x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00055
00056 for (int32_t i=0; i<problem.l; i++)
00057 {
00058 problem.pv[i]=-1.0;
00059 problem.y[i]=labels->get_label(i);
00060 problem.x[i]=&x_space[2*i];
00061 x_space[2*i].index=i;
00062 x_space[2*i+1].index=-1;
00063 }
00064
00065 ASSERT(kernel);
00066
00067 param.svm_type=solver_type;
00068 param.kernel_type = LINEAR;
00069 param.degree = 3;
00070 param.gamma = 0;
00071 param.coef0 = 0;
00072 param.nu = get_nu();
00073 param.kernel=kernel;
00074 param.cache_size = kernel->get_cache_size();
00075 param.max_train_time = max_train_time;
00076 param.C = get_C1();
00077 param.eps = epsilon;
00078 param.p = 0.1;
00079 param.shrinking = 1;
00080 param.nr_weight = 0;
00081 param.weight_label = NULL;
00082 param.weight = NULL;
00083 param.use_bias = get_bias_enabled();
00084
00085 const char* error_msg = svm_check_parameter(&problem,¶m);
00086
00087 if(error_msg)
00088 SG_ERROR("Error: %s\n",error_msg);
00089
00090 model = svm_train(&problem, ¶m);
00091
00092 if (model)
00093 {
00094 if (model->nr_class!=num_classes)
00095 {
00096 SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
00097 model->nr_class, num_classes);
00098 }
00099 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00100 create_multiclass_svm(num_classes);
00101
00102 int32_t* offsets=SG_MALLOC(int32_t, num_classes);
00103 offsets[0]=0;
00104
00105 for (int32_t i=1; i<num_classes; i++)
00106 offsets[i] = offsets[i-1]+model->nSV[i-1];
00107
00108 int32_t s=0;
00109 for (int32_t i=0; i<num_classes; i++)
00110 {
00111 for (int32_t j=i+1; j<num_classes; j++)
00112 {
00113 int32_t k, l;
00114
00115 float64_t sgn=1;
00116 if (model->label[i]>model->label[j])
00117 sgn=-1;
00118
00119 int32_t num_sv=model->nSV[i]+model->nSV[j];
00120 float64_t bias=-model->rho[s];
00121
00122 ASSERT(num_sv>0);
00123 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00124
00125 CSVM* svm=new CSVM(num_sv);
00126
00127 svm->set_bias(sgn*bias);
00128
00129 int32_t sv_idx=0;
00130 for (k=0; k<model->nSV[i]; k++)
00131 {
00132 svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00133 svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00134 sv_idx++;
00135 }
00136
00137 for (k=0; k<model->nSV[j]; k++)
00138 {
00139 svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00140 svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00141 sv_idx++;
00142 }
00143
00144 int32_t idx=0;
00145
00146 if (sgn>0)
00147 {
00148 for (k=0; k<model->label[i]; k++)
00149 idx+=num_classes-k-1;
00150
00151 for (l=model->label[i]+1; l<model->label[j]; l++)
00152 idx++;
00153 }
00154 else
00155 {
00156 for (k=0; k<model->label[j]; k++)
00157 idx+=num_classes-k-1;
00158
00159 for (l=model->label[j]+1; l<model->label[i]; l++)
00160 idx++;
00161 }
00162
00163
00164
00165
00166
00167
00168
00169 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f label:(%d,%d) -> svm[%d]\n", s, num_sv, model->l, bias, model->label[i], model->label[j], idx);
00170
00171 set_svm(idx, svm);
00172 s++;
00173 }
00174 }
00175
00176 CSVM::set_objective(model->objective);
00177
00178 SG_FREE(offsets);
00179 SG_FREE(problem.x);
00180 SG_FREE(problem.y);
00181 SG_FREE(x_space);
00182
00183 svm_destroy_model(model);
00184 model=NULL;
00185
00186 return true;
00187 }
00188 else
00189 return false;
00190 }
00191