18 using namespace shogun;
20 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
25 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
33 void CPrimalMosekSOSVM::init()
38 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
42 bool CPrimalMosekSOSVM::train_machine(
CFeatures* data)
47 CFeatures* model_features = get_features();
49 m_model->check_training_setup();
52 int32_t M = m_model->get_dim();
54 int32_t num_aux = m_model->get_num_aux();
56 int32_t num_aux_con = m_model->get_num_aux_con();
61 CMosek* mosek =
new CMosek(0, M+num_aux+N);
63 if ( mosek->get_rescode() != MSK_RES_OK )
65 SG_PRINT(
"Mosek object could not be properly created..."
66 "aborting training of PrimalMosekSOSVM\n");
74 m_model->init_opt(A, a, B, b, lb, ub, C);
77 if ( mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) != MSK_RES_OK )
95 for ( int32_t i = 0 ; i < N ; ++i )
102 int32_t num_con = num_aux_con;
103 int32_t old_num_con = num_con;
108 CList* cur_list = NULL;
109 bool exception =
false;
118 old_num_con = num_con;
120 for ( int32_t i = 0 ; i < N ; ++i )
123 result = m_model->
argmax(m_w, i);
126 slack = m_loss->loss( compute_loss_arg(result) );
137 while ( cur_res != NULL )
140 m_loss->loss( compute_loss_arg(cur_res) ));
146 if ( slack > max_slack )
150 if ( ! insert_result(cur_list, result) )
156 add_constraint(mosek, result, num_con, i);
163 if ( ! insert_result(cur_list, result) )
169 add_constraint(mosek, result, num_con, i);
178 mosek->optimize(sol);
179 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
183 else if ( i < M+num_aux )
186 m_slacks[i-M-num_aux] = sol[i];
189 }
while ( old_num_con != num_con && ! exception );
201 int32_t M = m_w.vlen;
208 bool CPrimalMosekSOSVM::insert_result(
CList* result_list,
CResultSet* result)
const
214 SG_PRINT(
"ResultSet could not be inserted in the list..."
215 "aborting training of PrimalMosekSOSVM\n");
221 bool CPrimalMosekSOSVM::add_constraint(
227 int32_t M = m_model->get_dim();
230 for (
int i = 0 ; i < M ; ++i )
233 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
234 m_model->get_num_aux(), -result->
delta) == MSK_RES_OK );