// Constructor fragments. This listing is partial: initializer lists, braces and
// (presumably) the signature of a second, parameterized constructor are not
// visible here. Each visible body calls register_parameters() and then
// set_max_iter(1000000).
35 CMulticlassOCAS::CMulticlassOCAS() :
38 register_parameters();
41 set_max_iter(1000000);
// NOTE(review): the two calls below repeat the pair above; they presumably
// belong to a second constructor whose header is missing from this extract.
49 register_parameters();
51 set_max_iter(1000000);
56 void CMulticlassOCAS::register_parameters()
65 CMulticlassOCAS::~CMulticlassOCAS()
// Trains the multiclass OCAS model. This listing is partial: braces, several
// declarations (labels, C, TolRel, TolAbs, QPBound, MaxTime, user_data), the
// tail of the solver call and the return statement are not visible.
// Visible flow: read the problem sizes, allocate the solver work buffers, run
// msvm_ocas_solver with this class's callbacks, log the solver statistics,
// refill m_machines inside a loop over the classes, then free the work buffers.
// How the `data` argument is used is not visible in this extract.
69 bool CMulticlassOCAS::train_machine(
CFeatures* data)
// A multiclass strategy must be set; it supplies the class count below.
76 ASSERT(m_multiclass_strategy)
// Problem dimensions taken from the features and the strategy.
78 int32_t num_vectors = m_features->get_num_vectors();
79 int32_t num_classes = m_multiclass_strategy->get_num_classes();
80 int32_t num_features = m_features->get_dim_feature_space();
// Unsigned copies in the types passed to msvm_ocas_solver.
84 uint32_t nY = num_classes;
85 uint32_t nData = num_vectors;
90 uint32_t BufSize = m_buf_size;
91 uint8_t Method = m_method;
// Fill the context handed to the callbacks through their void* user_data argument.
94 user_data.features = m_features;
// W, oldW and new_a each hold num_features*num_classes values. The count is
// widened to int64_t before multiplying.
// NOTE(review): the callbacks compute the same count as a uint32_t product
// (nY*nDim), which can wrap for very large problems - confirm limits.
95 user_data.W = SG_CALLOC(
float64_t, (int64_t)num_features*num_classes);
96 user_data.oldW = SG_CALLOC(float64_t, (int64_t)num_features*num_classes);
97 user_data.new_a = SG_CALLOC(float64_t, (int64_t)num_features*num_classes);
// full_A is m_buf_size times larger; presumably one slot per buffered cutting
// plane - TODO confirm against LIBOCAS_INDEX.
98 user_data.full_A = SG_CALLOC(float64_t, (int64_t)num_features*num_classes*m_buf_size);
// Scratch buffer with one value per training vector.
99 user_data.output_values = SG_CALLOC(float64_t, num_vectors);
100 user_data.data_y = labels.vector;
101 user_data.nY = num_classes;
102 user_data.nDim = num_features;
103 user_data.nData = num_vectors;
// Run the solver with the static callbacks of this class. The remaining
// arguments and the closing parenthesis are not visible in this extract.
105 ocas_return_value_T value =
106 msvm_ocas_solver(C, labels.vector, nY, nData, TolRel, TolAbs,
107 QPBound, MaxTime, BufSize, Method,
108 &CMulticlassOCAS::msvm_full_compute_W,
109 &CMulticlassOCAS::msvm_update_W,
110 &CMulticlassOCAS::msvm_full_add_new_cut,
111 &CMulticlassOCAS::msvm_full_compute_output,
112 &CMulticlassOCAS::msvm_sort_data,
113 &CMulticlassOCAS::msvm_print,
// Debug-level dump of every field of the solver's return value.
116 SG_DEBUG("Number of iterations [nIter] = %d \n",value.nIter)
117 SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes)
118 SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha)
119 SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err)
120 SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P)
121 SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D)
122 SG_DEBUG("Output time [output_time] = %f \n",value.output_time)
123 SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time)
124 SG_DEBUG("Add time [add_time] = %f \n",value.add_time)
125 SG_DEBUG("W time [w_time] = %f \n",value.w_time)
126 SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time)
127 SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time)
128 SG_DEBUG("Print time [print_time] = %f \n",value.print_time)
129 SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag)
130 SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag)
// Reset the machine array, then push machines inside a loop over the classes.
// The code that builds `machine` (presumably from that class's slice of W) is
// not visible here.
132 m_machines->reset_array();
133 for (int32_t i=0; i<num_classes; i++)
138 m_machines->push_back(machine);
// Release every work buffer allocated above.
141 SG_FREE(user_data.W);
142 SG_FREE(user_data.oldW);
143 SG_FREE(user_data.new_a);
144 SG_FREE(user_data.full_A);
145 SG_FREE(user_data.output_values);
// Solver callback (passed to msvm_ocas_solver in train_machine): blends the
// previous and current weight vectors in place.
// t         - interpolation weight; each entry becomes oldW[j]*(1-t) + t*W[j].
// user_data - pointer to the mocas_data context holding W, oldW, nY and nDim.
// The squared norm of the updated W is computed at the end; the return
// statement is not visible in this extract (presumably it returns sq_norm_W).
150 float64_t CMulticlassOCAS::msvm_update_W(float64_t t,
void* user_data)
152 float64_t* W = ((mocas_data*)user_data)->W;
153 float64_t* oldW = ((mocas_data*)user_data)->oldW;
154 uint32_t nY = ((mocas_data*)user_data)->nY;
155 uint32_t nDim = ((mocas_data*)user_data)->nDim;
// NOTE(review): nY*nDim is a uint32_t product and can wrap for very large problems.
157 for(uint32_t j=0; j < nY*nDim; j++)
158 W[j] = oldW[j]*(1-t) + t*W[j];
// Dot product of W with itself over all nDim*nY entries.
160 float64_t sq_norm_W =
CMath::dot(W,W,nDim*nY);
// Solver callback: recomputes W from the stored cutting planes.
// sq_norm_W, dp_WoldW - output pointers; the code that fills them is not
//                       visible in this extract.
// alpha               - one weight per selected cutting plane.
// nSel                - number of cutting planes to combine.
// user_data           - pointer to the mocas_data context (W, oldW, full_A, nY, nDim).
// Braces, the declarations of i and j, and the lines between the two loop
// headers are missing from this listing.
165 void CMulticlassOCAS::msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW,
166 float64_t *alpha, uint32_t nSel,
void* user_data)
168 float64_t* W = ((mocas_data*)user_data)->W;
169 float64_t* oldW = ((mocas_data*)user_data)->oldW;
170 float64_t* full_A = ((mocas_data*)user_data)->full_A;
171 uint32_t nY = ((mocas_data*)user_data)->nY;
172 uint32_t nDim = ((mocas_data*)user_data)->nDim;
// Keep the previous solution in oldW, then start W from zero.
176 memcpy(oldW, W,
sizeof(float64_t)*nDim*nY);
177 memset(W, 0,
sizeof(float64_t)*nDim*nY);
// Accumulate alpha[i] times the i-th stored plane into W.
// NOTE(review): lines between the loop headers are missing; there may be a
// guard on alpha[i] here - confirm against the full file.
179 for(i=0; i<nSel; i++)
183 for(j=0; j<nDim*nY; j++)
184 W[j] += alpha[i]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
// Solver callback: builds a new cutting plane in new_a, stores it in full_A at
// slot nSel, and fills the new column of H.
// new_col_H - output; entry nSel receives sq_norm_a (computed on lines not
//             visible in this extract).
// new_cut   - one entry per training example, read as a class index (y2).
// nSel      - number of planes already stored; also the slot for the new one.
// user_data - pointer to the mocas_data context.
// Braces, the per-example update of new_a, the handling of `tmp` and the
// return statement are missing from this listing.
194 int CMulticlassOCAS::msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut,
195 uint32_t nSel,
void* user_data)
197 float64_t* full_A = ((mocas_data*)user_data)->full_A;
198 float64_t* new_a = ((mocas_data*)user_data)->new_a;
199 float64_t* data_y = ((mocas_data*)user_data)->data_y;
200 uint32_t nY = ((mocas_data*)user_data)->nY;
201 uint32_t nDim = ((mocas_data*)user_data)->nDim;
202 uint32_t nData = ((mocas_data*)user_data)->nData;
203 CDotFeatures* features = ((mocas_data*)user_data)->features;
206 uint32_t i, j, y, y2;
// Start the new plane from zero.
208 memset(new_a, 0,
sizeof(float64_t)*nDim*nY);
// For each example: y is its stored label, y2 the class index from new_cut.
// NOTE(review): the statements that use y and y2 to update new_a are not
// visible here (presumably via features->add_to_dense_vec) - confirm.
210 for(i=0; i < nData; i++)
212 y = (uint32_t)(data_y[i]);
213 y2 = (uint32_t)new_cut[i];
// Copy the finished plane into slot nSel of full_A.
223 for(j=0; j < nDim*nY; j++ )
224 full_A[LIBOCAS_INDEX(j,nSel,nDim*nY)] = new_a[j];
226 new_col_H[nSel] = sq_norm_a;
// For every previously stored plane i, accumulate its dot product with new_a
// in tmp. Where tmp is stored is not visible (presumably new_col_H[i]).
227 for(i=0; i < nSel; i++)
231 for(j=0; j < nDim*nY; j++ )
232 tmp += new_a[j]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
// Solver callback: fills `output` with the scores of all training vectors.
// output    - destination, written at LIBOCAS_INDEX(y,i,nY) for class y and
//             example i.
// user_data - pointer to the mocas_data context.
// Braces, the declarations of i and y, the loop over y (presumably one pass
// per class) and the return statement are missing from this listing.
240 int CMulticlassOCAS::msvm_full_compute_output(float64_t *output,
void* user_data)
242 float64_t* W = ((mocas_data*)user_data)->W;
243 uint32_t nY = ((mocas_data*)user_data)->nY;
244 uint32_t nDim = ((mocas_data*)user_data)->nDim;
245 uint32_t nData = ((mocas_data*)user_data)->nData;
246 float64_t* output_values = ((mocas_data*)user_data)->output_values;
247 CDotFeatures* features = ((mocas_data*)user_data)->features;
// Compute output_values for vectors [0, nData) against the nDim-long segment
// of W starting at nDim*y, with no alphas (NULL) and b = 0.0.
253 features->
dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0);
// Scatter the per-example results for class y into the solver's output layout.
254 for (i=0; i<nData; i++)
255 output[LIBOCAS_INDEX(y,i,nY)] = output_values[i];
261 int CMulticlassOCAS::msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size)
267 void CMulticlassOCAS::msvm_print(ocas_return_value_T value)
271 #endif //USE_GPL_SHOGUN
virtual void dense_dot_range(float64_t *output, int32_t start, int32_t stop, float64_t *alphas, float64_t *vec, int32_t dim, float64_t b)
virtual void set_w(const SGVector< float64_t > src_w)
The class Labels models labels, i.e. class assignments of objects.
static void qsort_index(T1 *output, T2 *index, uint32_t size)
virtual void add_to_dense_vec(float64_t alpha, int32_t vec_idx1, float64_t *vec2, int32_t vec2_len, bool abs_val=false)=0
Features that support dot products among other operations.
Multiclass Labels for multi-class classification.
generic linear multiclass machine
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute dot product between v1 and v2 (blas optimized)
All classes and functions are contained in the shogun namespace.
The class Features is the base class of all feature objects.
SGVector< T > clone() const
void set_epsilon(float *begin, float max)
Multiclass one-vs-rest strategy used to train generic multiclass machines for K-class problems by building a voting-based ensemble of K binary classifiers.