27 bool do_weighted_averaging,
31 REQUIRE(model != NULL && labs != NULL,
32 "%s::CStochasticSOSVM(): model and labels cannot be NULL!\n",
get_name());
35 "%s::CStochasticSOSVM(): number of labels should be greater than 0!\n",
get_name());
39 m_do_weighted_averaging = do_weighted_averaging;
43 void CStochasticSOSVM::init()
53 m_do_weighted_averaging =
true;
54 m_debug_multiplier = 0;
69 SG_DEBUG(
"Entering CStochasticSOSVM::train_machine.\n");
77 SG_DEBUG(
"The training setup is correct.\n");
91 if (m_do_weighted_averaging)
104 int32_t debug_iter = 1;
105 if (m_debug_multiplier == 0)
108 m_debug_multiplier = 100;
115 for (int32_t pi = 0; pi < m_num_iter; ++pi)
117 for (int32_t si = 0; si < N; ++si)
144 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
149 w_s.
scale(1.0 / (N*m_lambda));
159 if (m_do_weighted_averaging)
173 if (m_do_weighted_averaging)
174 w_debug = w_avg.
clone();
181 SG_DEBUG(
"pass %d (iteration %d), SVM primal = %f, train_error = %f \n",
182 pi, k, primal, train_error);
186 debug_iter =
CMath::min(debug_iter+N, debug_iter*(1+m_debug_multiplier/100));
191 if (m_do_weighted_averaging)
197 SG_DEBUG(
"Leaving CStochasticSOSVM::train_machine.\n");
218 m_num_iter = num_iter;
223 return m_debug_multiplier;
228 m_debug_multiplier = multiplier;
238 m_rand_seed = rand_seed;
SGVector< float64_t > psi_truth
Base class of the labels used in Structured Output (SO) problems.
uint32_t get_rand_seed() const
int32_t get_debug_multiplier() const
void set_debug_multiplier(int32_t multiplier)
virtual const char * get_name() const
virtual bool train_machine(CFeatures *data=NULL)
static float64_t primal_objective(SGVector< float64_t > w, CStructuredModel *model, float64_t lbda)
CStructuredModel * m_model
virtual int32_t get_dim() const =0
void set_num_iter(int32_t num_iter)
void scale(T alpha)
Scale vector inplace.
void set_rand_seed(uint32_t rand_seed)
virtual void init_training()
void set_features(CFeatures *f)
static void init_random(uint32_t initseed=0)
void add_to_dense(T alpha, T *vec, int32_t dim, bool abs_val=false)
class CSOSVMHelper contains helper functions to compute primal objectives, dual objectives, average training losses, duality gaps etc. These values will be recorded to check convergence. This class is inspired by the matlab implementation of the block coordinate Frank-Wolfe SOSVM solver [1].
virtual bool check_training_setup() const
static float64_t average_loss(SGVector< float64_t > w, CStructuredModel *model, bool is_ub=false)
Class CStructuredModel that represents the application specific model and contains most of the applic...
all of classes and functions are contained in the shogun namespace
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)=0
virtual void add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error, float64_t dual=-1, float64_t dgap=-1)
virtual int32_t get_num_labels() const
The class Features is the base class of all feature objects.
SGVector< T > clone() const
SGVector< float64_t > m_w
void set_lambda(float64_t lbda)
SGVector< float64_t > psi_pred
virtual EMachineType get_classifier_type()
SGSparseVector< float64_t > psi_truth_sparse
static CStructuredLabels * to_structured(CLabels *base_labels)
virtual const char * get_name() const
float64_t get_lambda() const
SGSparseVector< float64_t > psi_pred_sparse
void add(const SGVector< T > x)
int32_t get_num_iter() const