00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifdef USE_MOSEK
00012
00013 #include <shogun/lib/DynamicObjectArray.h>
00014 #include <shogun/lib/List.h>
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/structure/PrimalMosekSOSVM.h>
00017
00018 using namespace shogun;
00019
00020 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
00021 : CLinearStructuredOutputMachine(),
00022 po_value(0.0)
00023 {
00024 }
00025
00026 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
00027 CStructuredModel* model,
00028 CLossFunction* loss,
00029 CStructuredLabels* labs)
00030 : CLinearStructuredOutputMachine(model, loss, labs),
00031 po_value(0.0)
00032 {
00033 }
00034
00035 void CPrimalMosekSOSVM::init()
00036 {
00037 SG_ADD(&m_slacks, "m_slacks", "Slacks vector", MS_NOT_AVAILABLE);
00038 }
00039
00040 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
00041 {
00042 }
00043
00044 bool CPrimalMosekSOSVM::train_machine(CFeatures* data)
00045 {
00046 if (data)
00047 set_features(data);
00048
00049 CFeatures* model_features = get_features();
00050
00051 m_model->check_training_setup();
00052
00053
00054 int32_t M = m_model->get_dim();
00055
00056 int32_t num_aux = m_model->get_num_aux();
00057
00058 int32_t num_aux_con = m_model->get_num_aux_con();
00059
00060 int32_t N = m_model->get_features()->get_num_vectors();
00061
00062
00063 CMosek* mosek = new CMosek(0, M+num_aux+N);
00064 SG_REF(mosek);
00065 if ( mosek->get_rescode() != MSK_RES_OK )
00066 {
00067 SG_PRINT("Mosek object could not be properly created..."
00068 "aborting training of PrimalMosekSOSVM\n");
00069
00070 return false;
00071 }
00072
00073
00074 SGMatrix< float64_t > A, B, C;
00075 SGVector< float64_t > a, b, lb, ub;
00076 m_model->init_opt(A, a, B, b, lb, ub, C);
00077
00078
00079 if ( mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) != MSK_RES_OK )
00080 {
00081
00082 return false;
00083 }
00084
00085
00086 m_w = SGVector< float64_t >(M);
00087 m_w.zero();
00088
00089 m_slacks = SGVector< float64_t >(N);
00090 m_slacks.zero();
00091
00092
00093
00094
00095 CDynamicObjectArray* results = new CDynamicObjectArray(N);
00096 SG_REF(results);
00097 for ( int32_t i = 0 ; i < N ; ++i )
00098 {
00099 CList* list = new CList(true);
00100 results->push_back(list);
00101 }
00102
00103
00104 int32_t num_con = num_aux_con;
00105 int32_t old_num_con = num_con;
00106 float64_t slack = 0.0;
00107 float64_t max_slack = 0.0;
00108 CResultSet* result = NULL;
00109 CResultSet* cur_res = NULL;
00110 CList* cur_list = NULL;
00111 bool exception = false;
00112
00113 SGVector< float64_t > sol(M+num_aux+N);
00114 sol.zero();
00115
00116 SGVector< float64_t > aux(num_aux);
00117
00118 do
00119 {
00120 old_num_con = num_con;
00121
00122 for ( int32_t i = 0 ; i < N ; ++i )
00123 {
00124
00125 result = m_model->argmax(m_w, i);
00126
00127
00128
00129 slack = m_loss->loss( compute_loss_arg(result) );
00130 cur_list = (CList*) results->get_element(i);
00131
00132
00133 if ( cur_list->get_num_elements() > 0 )
00134 {
00135
00136
00137 cur_res = (CResultSet*) cur_list->get_first_element();
00138 max_slack = -CMath::INFTY;
00139
00140 while ( cur_res != NULL )
00141 {
00142 max_slack = CMath::max(max_slack,
00143 m_loss->loss( compute_loss_arg(cur_res) ));
00144
00145 SG_UNREF(cur_res);
00146 cur_res = (CResultSet*) cur_list->get_next_element();
00147 }
00148
00149 if ( slack > max_slack )
00150 {
00151
00152
00153 if ( ! insert_result(cur_list, result) )
00154 {
00155 exception = true;
00156 break;
00157 }
00158
00159 add_constraint(mosek, result, num_con, i);
00160 ++num_con;
00161 }
00162 }
00163 else
00164 {
00165
00166 if ( ! insert_result(cur_list, result) )
00167 {
00168 exception = true;
00169 break;
00170 }
00171
00172 add_constraint(mosek, result, num_con, i);
00173 ++num_con;
00174 }
00175
00176 SG_UNREF(cur_list);
00177 SG_UNREF(result);
00178 }
00179
00180
00181 mosek->optimize(sol);
00182 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
00183 {
00184 if ( i < M )
00185 m_w[i] = sol[i];
00186 else if ( i < M+num_aux )
00187 aux[i-M] = sol[i];
00188 else
00189 m_slacks[i-M-num_aux] = sol[i];
00190 }
00191
00192 } while ( old_num_con != num_con && ! exception );
00193
00194 po_value = mosek->get_primal_objective_value();
00195
00196
00197 SG_UNREF(results);
00198 SG_UNREF(mosek);
00199 SG_UNREF(model_features);
00200 return true;
00201 }
00202
00203 float64_t CPrimalMosekSOSVM::compute_loss_arg(CResultSet* result) const
00204 {
00205
00206 int32_t M = m_w.vlen;
00207
00208 return SGVector< float64_t >::dot(m_w.vector, result->psi_pred.vector, M) +
00209 result->delta -
00210 SGVector< float64_t >::dot(m_w.vector, result->psi_truth.vector, M);
00211 }
00212
00213 bool CPrimalMosekSOSVM::insert_result(CList* result_list, CResultSet* result) const
00214 {
00215 bool succeed = result_list->insert_element(result);
00216
00217 if ( ! succeed )
00218 {
00219 SG_PRINT("ResultSet could not be inserted in the list..."
00220 "aborting training of PrimalMosekSOSVM\n");
00221 }
00222
00223 return succeed;
00224 }
00225
00226 bool CPrimalMosekSOSVM::add_constraint(
00227 CMosek* mosek,
00228 CResultSet* result,
00229 index_t con_idx,
00230 index_t train_idx) const
00231 {
00232 int32_t M = m_model->get_dim();
00233 SGVector< float64_t > dPsi(M);
00234
00235 for ( int i = 0 ; i < M ; ++i )
00236 dPsi[i] = result->psi_pred[i] - result->psi_truth[i];
00237
00238 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
00239 m_model->get_num_aux(), -result->delta) == MSK_RES_OK );
00240 }
00241
00242
00243 float64_t CPrimalMosekSOSVM::compute_primal_objective() const
00244 {
00245 return po_value;
00246 }
00247
00248 #endif