22 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
29 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
38 void CPrimalMosekSOSVM::init()
47 m_regularization = 1.0;
51 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
55 bool CPrimalMosekSOSVM::train_machine(
CFeatures* data)
57 SG_DEBUG(
"Entering CPrimalMosekSOSVM::train_machine.\n");
61 CFeatures* model_features = get_features();
63 m_model->init_training();
65 m_model->check_training_setup();
66 SG_DEBUG(
"The training setup is correct.\n");
69 int32_t M = m_model->get_dim();
71 int32_t num_aux = m_model->get_num_aux();
73 int32_t num_aux_con = m_model->get_num_aux_con();
77 SG_DEBUG(
"M=%d, N =%d, num_aux=%d, num_aux_con=%d.\n", M, N, num_aux, num_aux_con);
80 CMosek* mosek =
new CMosek(0, M+num_aux+N);
82 REQUIRE(mosek->get_rescode() == MSK_RES_OK,
"Mosek object could not be properly created in PrimalMosekSOSVM training.\n");
87 m_model->init_primal_opt(m_regularization, A, a, B, b, lb, ub, C);
90 "%s::train_machine(): lb.vlen can only be 0 or w.vlen!\n", get_name());
93 "%s::train_machine(): ub.vlen can only be 0 or w.vlen!\n", get_name());
101 SG_DEBUG(
"Regularization used in PrimalMosekSOSVM equal to %.2f.\n", m_regularization);
104 REQUIRE(mosek->init_sosvm(M, N, num_aux, num_aux_con, C, m_lb, m_ub, A, b) == MSK_RES_OK,
105 "Mosek error in PrimalMosekSOSVM initializing SO-SVM.\n")
119 for ( int32_t i = 0 ; i < N ; ++i )
126 int32_t num_con = num_aux_con;
127 int32_t old_num_con = num_con;
128 bool exception =
false;
138 SG_DEBUG(
"Iteration #%d: Cutting plane training with num_con=%d and old_num_con=%d.\n",
139 iteration, num_con, old_num_con);
141 old_num_con = num_con;
143 for ( int32_t i = 0 ; i < N ; ++i )
160 while ( cur_res != NULL )
168 if ( slack > max_slack + m_epsilon )
172 if ( ! insert_result(cur_list, result) )
178 add_constraint(mosek, result, num_con, i);
185 if ( ! insert_result(cur_list, result) )
191 add_constraint(mosek, result, num_con, i);
200 SG_DEBUG(
"Entering Mosek QP solver.\n");
202 mosek->optimize(sol);
203 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
207 else if ( i < M+num_aux )
210 m_slacks[i-M-num_aux] = sol[i];
213 SG_DEBUG(
"QP solved. The primal objective value is %.4f.\n", mosek->get_primal_objective_value());
217 }
while ( old_num_con != num_con && ! exception );
219 po_value = mosek->get_primal_objective_value();
231 int32_t M = m_w.vlen;
247 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
248 "to be set true\n", m_model->get_name());
253 bool CPrimalMosekSOSVM::insert_result(
CList* result_list,
CResultSet* result)
const
259 SG_PRINT(
"ResultSet could not be inserted in the list..."
260 "aborting training of PrimalMosekSOSVM\n");
266 bool CPrimalMosekSOSVM::add_constraint(
272 int32_t M = m_model->get_dim();
277 for (
int i = 0 ; i < M ; ++i )
288 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
289 "to be set true\n", m_model->get_name());
292 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
293 m_model->get_num_aux(), -result->
delta) == MSK_RES_OK );
297 float64_t CPrimalMosekSOSVM::compute_primal_objective()
const
307 void CPrimalMosekSOSVM::set_regularization(
float64_t C)
309 m_regularization = C;
SGVector< float64_t > psi_truth
float64_t loss(float64_t prediction, float64_t label)
Base class of the labels used in Structured Output (SO) problems.
virtual bool init(CFeatures *features)=0
CSGObject * get_next_element()
static const float64_t INFTY
infinity
virtual int32_t get_num_vectors() const =0
static const float64_t epsilon
CSGObject * get_first_element()
void add_to_dense(T alpha, T *vec, int32_t dim, bool abs_val=false)
int32_t get_num_elements()
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute dot product between v1 and v2 (blas optimized)
T dense_dot(T alpha, T *vec, int32_t dim, T b)
Class CStructuredModel that represents the application specific model and contains most of the applic...
all of classes and functions are contained in the shogun namespace
The class Features is the base class of all feature objects.
SGVector< T > clone() const
SGVector< float64_t > psi_pred
CSGObject * get_element(int32_t index) const
CHingeLoss implements the hinge loss function.
SGSparseVector< float64_t > psi_truth_sparse
void push_back(CSGObject *e)
void set_epsilon(float *begin, float max)
SGSparseVector< float64_t > psi_pred_sparse
Class List implements a doubly connected list for low-level-objects.
bool insert_element(CSGObject *data)