19 using namespace shogun;
21 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
28 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
37 void CPrimalMosekSOSVM::init()
44 m_regularization = 1.0;
48 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
52 bool CPrimalMosekSOSVM::train_machine(
CFeatures* data)
54 SG_DEBUG(
"Entering CPrimalMosekSOSVM::train_machine.\n");
58 CFeatures* model_features = get_features();
60 m_model->init_training();
62 m_model->check_training_setup();
63 SG_DEBUG(
"The training setup is correct.\n");
66 int32_t M = m_model->get_dim();
68 int32_t num_aux = m_model->get_num_aux();
70 int32_t num_aux_con = m_model->get_num_aux_con();
74 SG_DEBUG(
"M=%d, N =%d, num_aux=%d, num_aux_con=%d.\n", M, N, num_aux, num_aux_con);
77 CMosek* mosek =
new CMosek(0, M+num_aux+N);
79 if ( mosek->get_rescode() != MSK_RES_OK )
81 SG_PRINT(
"Mosek object could not be properly created..."
82 "aborting training of PrimalMosekSOSVM\n");
90 m_model->init_primal_opt(m_regularization, A, a, B, b, lb, ub, C);
92 SG_DEBUG(
"Regularization used in PrimalMosekSOSVM equal to %.2f.\n", m_regularization);
95 if ( mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) != MSK_RES_OK )
113 for ( int32_t i = 0 ; i < N ; ++i )
120 int32_t num_con = num_aux_con;
121 int32_t old_num_con = num_con;
126 CList* cur_list = NULL;
127 bool exception =
false;
137 SG_DEBUG(
"Iteration #%d: Cutting plane training with num_con=%d and old_num_con=%d.\n",
138 iteration, num_con, old_num_con);
140 old_num_con = num_con;
142 for ( int32_t i = 0 ; i < N ; ++i )
145 result = m_model->
argmax(m_w, i);
159 while ( cur_res != NULL )
162 CHingeLoss().loss( compute_loss_arg(cur_res) ));
168 if ( slack > max_slack + m_epsilon )
172 if ( ! insert_result(cur_list, result) )
178 add_constraint(mosek, result, num_con, i);
185 if ( ! insert_result(cur_list, result) )
191 add_constraint(mosek, result, num_con, i);
200 SG_DEBUG(
"Entering Mosek QP solver.\n");
202 mosek->optimize(sol);
203 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
207 else if ( i < M+num_aux )
210 m_slacks[i-M-num_aux] = sol[i];
213 SG_DEBUG(
"QP solved. The primal objective value is %.4f.\n", mosek->get_primal_objective_value());
217 }
while ( old_num_con != num_con && ! exception );
219 po_value = mosek->get_primal_objective_value();
231 int32_t M = m_w.vlen;
238 bool CPrimalMosekSOSVM::insert_result(
CList* result_list,
CResultSet* result)
const
244 SG_PRINT(
"ResultSet could not be inserted in the list..."
245 "aborting training of PrimalMosekSOSVM\n");
251 bool CPrimalMosekSOSVM::add_constraint(
257 int32_t M = m_model->get_dim();
260 for (
int i = 0 ; i < M ; ++i )
263 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
264 m_model->get_num_aux(), -result->
delta) == MSK_RES_OK );
268 float64_t CPrimalMosekSOSVM::compute_primal_objective()
const
278 void CPrimalMosekSOSVM::set_regularization(
float64_t C)
280 m_regularization = C;