00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <shogun/multiclass/MulticlassOCAS.h>
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/labels/MulticlassLabels.h>
00016
00017 using namespace shogun;
00018
00019 struct mocas_data
00020 {
00021 CDotFeatures* features;
00022 float64_t* W;
00023 float64_t* oldW;
00024 float64_t* full_A;
00025 float64_t* data_y;
00026 float64_t* output_values;
00027 uint32_t nY;
00028 uint32_t nData;
00029 uint32_t nDim;
00030 float64_t* new_a;
00031 };
00032
00033 CMulticlassOCAS::CMulticlassOCAS() :
00034 CLinearMulticlassMachine()
00035 {
00036 register_parameters();
00037 set_C(1.0);
00038 set_epsilon(1e-2);
00039 set_max_iter(1000000);
00040 set_method(1);
00041 set_buf_size(5000);
00042 }
00043
00044 CMulticlassOCAS::CMulticlassOCAS(float64_t C, CDotFeatures* train_features, CLabels* train_labels) :
00045 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(), train_features, NULL, train_labels), m_C(C)
00046 {
00047 register_parameters();
00048 set_epsilon(1e-2);
00049 set_max_iter(1000000);
00050 set_method(1);
00051 set_buf_size(5000);
00052 }
00053
00054 void CMulticlassOCAS::register_parameters()
00055 {
00056 SG_ADD(&m_C, "m_C", "regularization constant", MS_AVAILABLE);
00057 SG_ADD(&m_epsilon, "m_epsilon", "solver relative tolerance", MS_NOT_AVAILABLE);
00058 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations", MS_NOT_AVAILABLE);
00059 SG_ADD(&m_method, "m_method", "used solver method", MS_NOT_AVAILABLE);
00060 SG_ADD(&m_buf_size, "m_buf_size", "buffer size", MS_NOT_AVAILABLE);
00061 }
00062
00063 CMulticlassOCAS::~CMulticlassOCAS()
00064 {
00065 }
00066
00067 bool CMulticlassOCAS::train_machine(CFeatures* data)
00068 {
00069 if (data)
00070 set_features((CDotFeatures*)data);
00071
00072 ASSERT(m_features);
00073 ASSERT(m_labels);
00074 ASSERT(m_multiclass_strategy);
00075
00076 int32_t num_vectors = m_features->get_num_vectors();
00077 int32_t num_classes = m_multiclass_strategy->get_num_classes();
00078 int32_t num_features = m_features->get_dim_feature_space();
00079
00080 float64_t C = m_C;
00081 SGVector<float64_t> labels = ((CMulticlassLabels*) m_labels)->get_labels();
00082 uint32_t nY = num_classes;
00083 uint32_t nData = num_vectors;
00084 float64_t TolRel = m_epsilon;
00085 float64_t TolAbs = 0.0;
00086 float64_t QPBound = 0.0;
00087 float64_t MaxTime = m_max_train_time;
00088 uint32_t BufSize = m_buf_size;
00089 uint8_t Method = m_method;
00090
00091 mocas_data user_data;
00092 user_data.features = m_features;
00093 user_data.W = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00094 user_data.oldW = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00095 user_data.new_a = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00096 user_data.full_A = SG_MALLOC(float64_t, (int64_t)num_features*num_classes*m_buf_size);
00097 user_data.output_values = SG_MALLOC(float64_t, num_vectors);
00098 user_data.data_y = labels.vector;
00099 user_data.nY = num_classes;
00100 user_data.nDim = num_features;
00101 user_data.nData = num_vectors;
00102
00103 ocas_return_value_T value =
00104 msvm_ocas_solver(C, labels.vector, nY, nData, TolRel, TolAbs,
00105 QPBound, MaxTime, BufSize, Method,
00106 &CMulticlassOCAS::msvm_full_compute_W,
00107 &CMulticlassOCAS::msvm_update_W,
00108 &CMulticlassOCAS::msvm_full_add_new_cut,
00109 &CMulticlassOCAS::msvm_full_compute_output,
00110 &CMulticlassOCAS::msvm_sort_data,
00111 &CMulticlassOCAS::msvm_print,
00112 &user_data);
00113
00114 SG_DEBUG("Number of iterations [nIter] = %d \n",value.nIter);
00115 SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes);
00116 SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha);
00117 SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err);
00118 SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P);
00119 SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D);
00120 SG_DEBUG("Output time [output_time] = %f \n",value.output_time);
00121 SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time);
00122 SG_DEBUG("Add time [add_time] = %f \n",value.add_time);
00123 SG_DEBUG("W time [w_time] = %f \n",value.w_time);
00124 SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time);
00125 SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time);
00126 SG_DEBUG("Print time [print_time] = %f \n",value.print_time);
00127 SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag);
00128 SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag);
00129
00130 m_machines->reset_array();
00131 for (int32_t i=0; i<num_classes; i++)
00132 {
00133 CLinearMachine* machine = new CLinearMachine();
00134 machine->set_w(SGVector<float64_t>(&user_data.W[i*num_features],num_features,false).clone());
00135
00136 m_machines->push_back(machine);
00137 }
00138
00139 SG_FREE(user_data.W);
00140 SG_FREE(user_data.oldW);
00141 SG_FREE(user_data.new_a);
00142 SG_FREE(user_data.full_A);
00143 SG_FREE(user_data.output_values);
00144
00145 return true;
00146 }
00147
00148 float64_t CMulticlassOCAS::msvm_update_W(float64_t t, void* user_data)
00149 {
00150 float64_t* W = ((mocas_data*)user_data)->W;
00151 float64_t* oldW = ((mocas_data*)user_data)->oldW;
00152 uint32_t nY = ((mocas_data*)user_data)->nY;
00153 uint32_t nDim = ((mocas_data*)user_data)->nDim;
00154
00155 for(uint32_t j=0; j < nY*nDim; j++)
00156 W[j] = oldW[j]*(1-t) + t*W[j];
00157
00158 float64_t sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY);
00159
00160 return sq_norm_W;
00161 }
00162
00163 void CMulticlassOCAS::msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW,
00164 float64_t *alpha, uint32_t nSel, void* user_data)
00165 {
00166 float64_t* W = ((mocas_data*)user_data)->W;
00167 float64_t* oldW = ((mocas_data*)user_data)->oldW;
00168 float64_t* full_A = ((mocas_data*)user_data)->full_A;
00169 uint32_t nY = ((mocas_data*)user_data)->nY;
00170 uint32_t nDim = ((mocas_data*)user_data)->nDim;
00171
00172 uint32_t i,j;
00173
00174 memcpy(oldW, W, sizeof(float64_t)*nDim*nY);
00175 memset(W, 0, sizeof(float64_t)*nDim*nY);
00176
00177 for(i=0; i<nSel; i++)
00178 {
00179 if(alpha[i] > 0)
00180 {
00181 for(j=0; j<nDim*nY; j++)
00182 W[j] += alpha[i]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
00183 }
00184 }
00185
00186 *sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY);
00187 *dp_WoldW = SGVector<float64_t>::dot(W,oldW,nDim*nY);
00188
00189 return;
00190 }
00191
00192 int CMulticlassOCAS::msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut,
00193 uint32_t nSel, void* user_data)
00194 {
00195 float64_t* full_A = ((mocas_data*)user_data)->full_A;
00196 float64_t* new_a = ((mocas_data*)user_data)->new_a;
00197 float64_t* data_y = ((mocas_data*)user_data)->data_y;
00198 uint32_t nY = ((mocas_data*)user_data)->nY;
00199 uint32_t nDim = ((mocas_data*)user_data)->nDim;
00200 uint32_t nData = ((mocas_data*)user_data)->nData;
00201 CDotFeatures* features = ((mocas_data*)user_data)->features;
00202
00203 float64_t sq_norm_a;
00204 uint32_t i, j, y, y2;
00205
00206 memset(new_a, 0, sizeof(float64_t)*nDim*nY);
00207
00208 for(i=0; i < nData; i++)
00209 {
00210 y = (uint32_t)(data_y[i]);
00211 y2 = (uint32_t)new_cut[i];
00212 if(y2 != y)
00213 {
00214 features->add_to_dense_vec(1.0,i,&new_a[nDim*y],nDim);
00215 features->add_to_dense_vec(-1.0,i,&new_a[nDim*y2],nDim);
00216 }
00217 }
00218
00219
00220 sq_norm_a = SGVector<float64_t>::dot(new_a,new_a,nDim*nY);
00221 for(j=0; j < nDim*nY; j++ )
00222 full_A[LIBOCAS_INDEX(j,nSel,nDim*nY)] = new_a[j];
00223
00224 new_col_H[nSel] = sq_norm_a;
00225 for(i=0; i < nSel; i++)
00226 {
00227 float64_t tmp = 0;
00228
00229 for(j=0; j < nDim*nY; j++ )
00230 tmp += new_a[j]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
00231
00232 new_col_H[i] = tmp;
00233 }
00234
00235 return 0;
00236 }
00237
00238 int CMulticlassOCAS::msvm_full_compute_output(float64_t *output, void* user_data)
00239 {
00240 float64_t* W = ((mocas_data*)user_data)->W;
00241 uint32_t nY = ((mocas_data*)user_data)->nY;
00242 uint32_t nDim = ((mocas_data*)user_data)->nDim;
00243 uint32_t nData = ((mocas_data*)user_data)->nData;
00244 float64_t* output_values = ((mocas_data*)user_data)->output_values;
00245 CDotFeatures* features = ((mocas_data*)user_data)->features;
00246
00247 uint32_t i, y;
00248
00249 for(y=0; y<nY; y++)
00250 {
00251 features->dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0);
00252 for (i=0; i<nData; i++)
00253 output[LIBOCAS_INDEX(y,i,nY)] = output_values[i];
00254 }
00255
00256 return 0;
00257 }
00258
00259 int CMulticlassOCAS::msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size)
00260 {
00261 CMath::qsort_index(vals, data, size);
00262 return 0;
00263 }
00264
00265 void CMulticlassOCAS::msvm_print(ocas_return_value_T value)
00266 {
00267 }