MulticlassOCAS.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009-2012 Vojtech Franc and Soeren Sonnenburg
00008  * Written (W) 2012 Sergey Lisitsyn
00009  * Copyright (C) 2009-2012 Vojtech Franc and Soeren Sonnenburg
00010  */
00011 
00012 #include <shogun/multiclass/MulticlassOCAS.h>
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/labels/MulticlassLabels.h>
00016 
00017 using namespace shogun;
00018 
00019 struct mocas_data
00020 {
00021     CDotFeatures* features;
00022     float64_t* W;
00023     float64_t* oldW;
00024     float64_t* full_A;
00025     float64_t* data_y;
00026     float64_t* output_values;
00027     uint32_t nY;
00028     uint32_t nData;
00029     uint32_t nDim;
00030     float64_t* new_a;
00031 };
00032 
00033 CMulticlassOCAS::CMulticlassOCAS() :
00034     CLinearMulticlassMachine()
00035 {
00036     register_parameters();
00037     set_C(1.0);
00038     set_epsilon(1e-2);
00039     set_max_iter(1000000);
00040     set_method(1);
00041     set_buf_size(5000);
00042 }
00043 
00044 CMulticlassOCAS::CMulticlassOCAS(float64_t C, CDotFeatures* train_features, CLabels* train_labels) :
00045     CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(), train_features, NULL, train_labels), m_C(C)
00046 {
00047     register_parameters();
00048     set_epsilon(1e-2);
00049     set_max_iter(1000000);
00050     set_method(1);
00051     set_buf_size(5000);
00052 }
00053 
00054 void CMulticlassOCAS::register_parameters()
00055 {
00056     SG_ADD(&m_C, "m_C", "regularization constant", MS_AVAILABLE);
00057     SG_ADD(&m_epsilon, "m_epsilon", "solver relative tolerance", MS_NOT_AVAILABLE);
00058     SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations", MS_NOT_AVAILABLE);
00059     SG_ADD(&m_method, "m_method", "used solver method", MS_NOT_AVAILABLE);
00060     SG_ADD(&m_buf_size, "m_buf_size", "buffer size", MS_NOT_AVAILABLE);
00061 }
00062 
00063 CMulticlassOCAS::~CMulticlassOCAS()
00064 {
00065 }
00066 
00067 bool CMulticlassOCAS::train_machine(CFeatures* data)
00068 {
00069     if (data)
00070         set_features((CDotFeatures*)data);
00071 
00072     ASSERT(m_features);
00073     ASSERT(m_labels);
00074     ASSERT(m_multiclass_strategy);
00075 
00076     int32_t num_vectors = m_features->get_num_vectors();
00077     int32_t num_classes = m_multiclass_strategy->get_num_classes();
00078     int32_t num_features = m_features->get_dim_feature_space();
00079 
00080     float64_t C = m_C;
00081     SGVector<float64_t> labels = ((CMulticlassLabels*) m_labels)->get_labels();
00082     uint32_t nY = num_classes;
00083     uint32_t nData = num_vectors;
00084     float64_t TolRel = m_epsilon;
00085     float64_t TolAbs = 0.0;
00086     float64_t QPBound = 0.0;
00087     float64_t MaxTime = m_max_train_time;
00088     uint32_t BufSize = m_buf_size;
00089     uint8_t Method = m_method;
00090 
00091     mocas_data user_data;
00092     user_data.features = m_features;
00093     user_data.W = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00094     user_data.oldW = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00095     user_data.new_a = SG_MALLOC(float64_t, (int64_t)num_features*num_classes);
00096     user_data.full_A = SG_MALLOC(float64_t, (int64_t)num_features*num_classes*m_buf_size);
00097     user_data.output_values = SG_MALLOC(float64_t, num_vectors);
00098     user_data.data_y = labels.vector;
00099     user_data.nY = num_classes;
00100     user_data.nDim = num_features;
00101     user_data.nData = num_vectors;
00102 
00103     ocas_return_value_T value =
00104     msvm_ocas_solver(C, labels.vector, nY, nData, TolRel, TolAbs,
00105                      QPBound, MaxTime, BufSize, Method,
00106                      &CMulticlassOCAS::msvm_full_compute_W,
00107                      &CMulticlassOCAS::msvm_update_W,
00108                      &CMulticlassOCAS::msvm_full_add_new_cut,
00109                      &CMulticlassOCAS::msvm_full_compute_output,
00110                      &CMulticlassOCAS::msvm_sort_data,
00111                      &CMulticlassOCAS::msvm_print,
00112                      &user_data);
00113 
00114     SG_DEBUG("Number of iterations [nIter] = %d \n",value.nIter);
00115     SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes);
00116     SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha);
00117     SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err);
00118     SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P);
00119     SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D);
00120     SG_DEBUG("Output time [output_time] = %f \n",value.output_time);
00121     SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time);
00122     SG_DEBUG("Add time [add_time] = %f \n",value.add_time);
00123     SG_DEBUG("W time [w_time] = %f \n",value.w_time);
00124     SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time);
00125     SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time);
00126     SG_DEBUG("Print time [print_time] = %f \n",value.print_time);
00127     SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag);
00128     SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag);
00129 
00130     m_machines->reset_array();
00131     for (int32_t i=0; i<num_classes; i++)
00132     {
00133         CLinearMachine* machine = new CLinearMachine();
00134         machine->set_w(SGVector<float64_t>(&user_data.W[i*num_features],num_features,false).clone());
00135 
00136         m_machines->push_back(machine);
00137     }
00138 
00139     SG_FREE(user_data.W);
00140     SG_FREE(user_data.oldW);
00141     SG_FREE(user_data.new_a);
00142     SG_FREE(user_data.full_A);
00143     SG_FREE(user_data.output_values);
00144 
00145     return true;
00146 }
00147 
00148 float64_t CMulticlassOCAS::msvm_update_W(float64_t t, void* user_data)
00149 {
00150     float64_t* W = ((mocas_data*)user_data)->W;
00151     float64_t* oldW = ((mocas_data*)user_data)->oldW;
00152     uint32_t nY = ((mocas_data*)user_data)->nY;
00153     uint32_t nDim = ((mocas_data*)user_data)->nDim;
00154 
00155     for(uint32_t j=0; j < nY*nDim; j++)
00156         W[j] = oldW[j]*(1-t) + t*W[j];
00157 
00158     float64_t sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY);
00159 
00160     return sq_norm_W;
00161 }
00162 
00163 void CMulticlassOCAS::msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW,
00164                                           float64_t *alpha, uint32_t nSel, void* user_data)
00165 {
00166     float64_t* W = ((mocas_data*)user_data)->W;
00167     float64_t* oldW = ((mocas_data*)user_data)->oldW;
00168     float64_t* full_A = ((mocas_data*)user_data)->full_A;
00169     uint32_t nY = ((mocas_data*)user_data)->nY;
00170     uint32_t nDim = ((mocas_data*)user_data)->nDim;
00171 
00172     uint32_t i,j;
00173 
00174     memcpy(oldW, W, sizeof(float64_t)*nDim*nY);
00175     memset(W, 0, sizeof(float64_t)*nDim*nY);
00176 
00177     for(i=0; i<nSel; i++)
00178     {
00179         if(alpha[i] > 0)
00180         {
00181             for(j=0; j<nDim*nY; j++)
00182                 W[j] += alpha[i]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
00183         }
00184     }
00185 
00186     *sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY);
00187     *dp_WoldW = SGVector<float64_t>::dot(W,oldW,nDim*nY);
00188 
00189     return;
00190 }
00191 
00192 int CMulticlassOCAS::msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut,
00193                                            uint32_t nSel, void* user_data)
00194 {
00195     float64_t* full_A = ((mocas_data*)user_data)->full_A;
00196     float64_t* new_a = ((mocas_data*)user_data)->new_a;
00197     float64_t* data_y = ((mocas_data*)user_data)->data_y;
00198     uint32_t nY = ((mocas_data*)user_data)->nY;
00199     uint32_t nDim = ((mocas_data*)user_data)->nDim;
00200     uint32_t nData = ((mocas_data*)user_data)->nData;
00201     CDotFeatures* features = ((mocas_data*)user_data)->features;
00202 
00203     float64_t sq_norm_a;
00204     uint32_t i, j, y, y2;
00205 
00206     memset(new_a, 0, sizeof(float64_t)*nDim*nY);
00207 
00208     for(i=0; i < nData; i++)
00209     {
00210         y = (uint32_t)(data_y[i]);
00211         y2 = (uint32_t)new_cut[i];
00212         if(y2 != y)
00213         {
00214             features->add_to_dense_vec(1.0,i,&new_a[nDim*y],nDim);
00215             features->add_to_dense_vec(-1.0,i,&new_a[nDim*y2],nDim);
00216         }
00217     }
00218 
00219     // compute new_a'*new_a and insert new_a to the last column of full_A
00220     sq_norm_a = SGVector<float64_t>::dot(new_a,new_a,nDim*nY);
00221     for(j=0; j < nDim*nY; j++ )
00222         full_A[LIBOCAS_INDEX(j,nSel,nDim*nY)] = new_a[j];
00223 
00224     new_col_H[nSel] = sq_norm_a;
00225     for(i=0; i < nSel; i++)
00226     {
00227         float64_t tmp = 0;
00228 
00229         for(j=0; j < nDim*nY; j++ )
00230             tmp += new_a[j]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
00231 
00232         new_col_H[i] = tmp;
00233     }
00234 
00235     return 0;
00236 }
00237 
00238 int CMulticlassOCAS::msvm_full_compute_output(float64_t *output, void* user_data)
00239 {
00240     float64_t* W = ((mocas_data*)user_data)->W;
00241     uint32_t nY = ((mocas_data*)user_data)->nY;
00242     uint32_t nDim = ((mocas_data*)user_data)->nDim;
00243     uint32_t nData = ((mocas_data*)user_data)->nData;
00244     float64_t* output_values = ((mocas_data*)user_data)->output_values;
00245     CDotFeatures* features = ((mocas_data*)user_data)->features;
00246 
00247     uint32_t i, y;
00248 
00249     for(y=0; y<nY; y++)
00250     {
00251         features->dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0);
00252         for (i=0; i<nData; i++)
00253             output[LIBOCAS_INDEX(y,i,nY)] = output_values[i];
00254     }
00255 
00256     return 0;
00257 }
00258 
00259 int CMulticlassOCAS::msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size)
00260 {
00261     CMath::qsort_index(vals, data, size);
00262     return 0;
00263 }
00264 
00265 void CMulticlassOCAS::msvm_print(ocas_return_value_T value)
00266 {
00267 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation