31 REQUIRE(model != NULL && labs != NULL,
32 "%s::CFWSOSVM(): model and labels cannot be NULL!\n",
get_name());
35 "%s::CFWSOSVM(): number of labels should be greater than 0!\n",
get_name());
39 m_do_line_search = do_line_search;
53 m_do_line_search =
true;
54 m_gap_threshold = 0.1;
69 SG_DEBUG(
"Entering CFWSOSVM::train_machine.\n");
77 SG_DEBUG(
"The training setup is correct.\n");
107 for (int32_t pi = 0; pi < m_num_iter; ++pi)
114 for (int32_t si = 0; si < N; ++si)
136 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
152 w_s.
scale(1.0 / (N*m_lambda));
168 SG_SPRINT(
"pass %d (iteration %d), primal = %f, dual = %f, duality gap = %f, train_error = %f \n",
169 pi, k, primal, dual, dual_gap, train_error);
175 if (dual_gap <= m_gap_threshold)
178 SG_DEBUG(
"current gap: %f, gap_threshold: %f\n", dual_gap, m_gap_threshold);
179 SG_DEBUG(
"Duality gap below threshold -- stopping!\n");
185 SG_DEBUG(
"current gap: %f.\n", dual_gap);
190 if (m_do_line_search)
192 gamma = dual_gap / (m_lambda \
194 gamma = ((gamma > 1 ? 1 : gamma) < 0) ? 0 : gamma;
199 m_ell = (1.0-gamma) * m_ell + gamma * ell_s;
206 SG_DEBUG(
"Leaving CFWSOSVM::train_machine.\n");
227 m_num_iter = num_iter;
232 return m_gap_threshold;
237 m_gap_threshold = gap_threshold;
SGVector< float64_t > psi_truth
Base class of the labels used in Structured Output (SO) problems.
void set_ell(float64_t ell)
static float64_t primal_objective(SGVector< float64_t > w, CStructuredModel *model, float64_t lbda)
CStructuredModel * m_model
virtual int32_t get_dim() const =0
static bool fequals_abs(const T &a, const T &b, const float64_t eps)
void scale(T alpha)
Scale the vector in place.
float64_t get_lambda() const
int32_t get_num_iter() const
virtual bool train_machine(CFeatures *data=NULL)
virtual void init_training()
void set_lambda(float64_t lbda)
float64_t get_gap_threshold() const
void set_features(CFeatures *f)
void add_to_dense(T alpha, T *vec, int32_t dim, bool abs_val=false)
static float64_t dual_objective(SGVector< float64_t > w, float64_t aloss, float64_t lbda)
Class CSOSVMHelper contains helper functions to compute primal objectives, dual objectives, average training losses, duality gaps, etc. These values are recorded to check convergence. This class is inspired by the MATLAB implementation of the block-coordinate Frank-Wolfe SOSVM solver [1].
void set_gap_threshold(float64_t gap_threshold)
float64_t get_ell() const
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute the dot product of v1 and v2 (BLAS-optimized).
virtual bool check_training_setup() const
static float64_t average_loss(SGVector< float64_t > w, CStructuredModel *model, bool is_ub=false)
Class CStructuredModel that represents the application specific model and contains most of the applic...
All classes and functions are contained in the shogun namespace.
virtual const char * get_name() const
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)=0
virtual void add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error, float64_t dual=-1, float64_t dgap=-1)
virtual int32_t get_num_labels() const
The class Features is the base class of all feature objects.
SGVector< T > clone() const
SGVector< float64_t > m_w
void set_num_iter(int32_t num_iter)
SGVector< float64_t > psi_pred
SGSparseVector< float64_t > psi_truth_sparse
static CStructuredLabels * to_structured(CLabels *base_labels)
virtual const char * get_name() const
virtual EMachineType get_classifier_type()
SGSparseVector< float64_t > psi_pred_sparse
void add(const SGVector< T > x)