SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
FWSOSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2014 Shell Hu
8  * Copyright (C) 2014 Shell Hu
9  */
10 
#include <shogun/structure/FWSOSVM.h>
#include <shogun/labels/LabelsFactory.h>
#include <shogun/mathematics/Math.h>
#include <shogun/lib/SGVector.h>
15 
16 using namespace shogun;
17 
20 {
21  init();
22 }
23 
25  CStructuredModel* model,
26  CStructuredLabels* labs,
27  bool do_line_search,
28  bool verbose)
29 : CLinearStructuredOutputMachine(model, labs)
30 {
31  REQUIRE(model != NULL && labs != NULL,
32  "%s::CFWSOSVM(): model and labels cannot be NULL!\n", get_name());
33 
34  REQUIRE(labs->get_num_labels() > 0,
35  "%s::CFWSOSVM(): number of labels should be greater than 0!\n", get_name());
36 
37  init();
38  m_lambda = 1.0 / labs->get_num_labels();
39  m_do_line_search = do_line_search;
40  m_verbose = verbose;
41 }
42 
43 void CFWSOSVM::init()
44 {
45  SG_ADD(&m_lambda, "lambda", "Regularization constant", MS_NOT_AVAILABLE);
46  SG_ADD(&m_num_iter, "num_iter", "Number of iterations", MS_NOT_AVAILABLE);
47  SG_ADD(&m_do_line_search, "do_line_search", "Do line search", MS_NOT_AVAILABLE);
48  SG_ADD(&m_gap_threshold, "gap_threshold", "Gap threshold", MS_NOT_AVAILABLE);
49  SG_ADD(&m_ell, "ell", "Average loss", MS_NOT_AVAILABLE);
50 
51  m_lambda = 1.0;
52  m_num_iter = 50;
53  m_do_line_search = true;
54  m_gap_threshold = 0.1;
55  m_ell = 0;
56 }
57 
59 {
60 }
61 
63 {
64  return CT_FWSOSVM;
65 }
66 
68 {
69  SG_DEBUG("Entering CFWSOSVM::train_machine.\n");
70  if (data)
71  set_features(data);
72 
73  // Initialize the model for training
75  // Check that the scenary is correct to start with training
77  SG_DEBUG("The training setup is correct.\n");
78 
79  // Dimensionality of the joint feature space
80  int32_t M = m_model->get_dim();
81  // Number of training examples
83 
84  SG_DEBUG("M=%d, N =%d.\n", M, N);
85 
86  // Initialize the weight vector
88  m_w.zero();
89 
90  // Initialize the average loss
91  m_ell = 0;
92 
93  // logging
94  if (m_verbose)
95  {
96  if (m_helper != NULL)
98 
99  m_helper = new CSOSVMHelper();
100  SG_REF(m_helper);
101  }
102 
103  // Main loop
104  int32_t k = 0;
105  SGVector<float64_t> w_s(M);
106  float64_t ell_s = 0;
107  for (int32_t pi = 0; pi < m_num_iter; ++pi)
108  {
109  // init w_s and ell_s
110  k = pi;
111  w_s.zero();
112  ell_s = 0;
113 
114  for (int32_t si = 0; si < N; ++si)
115  {
116  // 1) solve the loss-augmented inference for point si
117  CResultSet* result = m_model->argmax(m_w, si);
118 
119  // 2) get the subgradient
120  // psi_i(y) := phi(x_i,y_i) - phi(x_i, y_pred)
121  SGVector<float64_t> psi_i(M);
122  if (result->psi_computed)
123  {
125  1.0, result->psi_truth.vector, -1.0, result->psi_pred.vector,
126  psi_i.vlen);
127  }
128  else if(result->psi_computed_sparse)
129  {
130  psi_i.zero();
131  result->psi_pred_sparse.add_to_dense(1.0, psi_i.vector, psi_i.vlen);
132  result->psi_truth_sparse.add_to_dense(-1.0, psi_i.vector, psi_i.vlen);
133  }
134  else
135  {
136  SG_ERROR("model(%s) should have either of psi_computed or psi_computed_sparse"
137  "to be set true\n", m_model->get_name());
138  }
139 
140  // 3) loss_i = L(y_i, y_pred)
141  float64_t loss_i = result->delta;
142  ASSERT(loss_i - CMath::dot(m_w.vector, psi_i.vector, m_w.vlen) >= -1e-12);
143 
144  // 4) update w_s and ell_s
145  w_s.add(psi_i);
146  ell_s += loss_i;
147 
148  SG_UNREF(result);
149 
150  } // end si
151 
152  w_s.scale(1.0 / (N*m_lambda));
153  ell_s /= N;
154 
155  // 5) duality gap
156  SGVector<float64_t> w_diff = m_w.clone();
157  SGVector<float64_t>::add(w_diff.vector, 1.0, m_w.vector, -1.0, w_s.vector, w_s.vlen);
158  float64_t dual_gap = m_lambda * CMath::dot(m_w.vector, w_diff.vector, m_w.vlen) - m_ell + ell_s;
159 
160  // Debug: compute primal and dual objectives and training error
161  if (m_verbose)
162  {
164  float64_t dual = CSOSVMHelper::dual_objective(m_w, m_ell, m_lambda);
165  ASSERT(CMath::fequals_abs(primal - dual, dual_gap, 1e-12));
166  float64_t train_error = CSOSVMHelper::average_loss(m_w, m_model); // Note train_error isn't ell_s
167 
168  SG_SPRINT("pass %d (iteration %d), primal = %f, dual = %f, duality gap = %f, train_error = %f \n",
169  pi, k, primal, dual, dual_gap, train_error);
170 
171  m_helper->add_debug_info(primal, (1.0*k) / N, train_error, dual, dual_gap);
172  }
173 
174  // 6) check duality gap
175  if (dual_gap <= m_gap_threshold)
176  {
177  SG_DEBUG("iteration %d...\n", k);
178  SG_DEBUG("current gap: %f, gap_threshold: %f\n", dual_gap, m_gap_threshold);
179  SG_DEBUG("Duality gap below threshold -- stopping!\n");
180  break; // stop main loop
181  }
182  else
183  {
184  SG_DEBUG("iteration %d...\n", k);
185  SG_DEBUG("current gap: %f.\n", dual_gap);
186  }
187 
188  // 7) step-size gamma
189  float64_t gamma = 1.0 / (k+1.0);
190  if (m_do_line_search)
191  {
192  gamma = dual_gap / (m_lambda \
193  * (CMath::dot(w_diff.vector, w_diff.vector, w_diff.vlen) + 1e-12));
194  gamma = ((gamma > 1 ? 1 : gamma) < 0) ? 0 : gamma; // clip to [0,1], or max(0,min(1,gamma))
195  }
196 
197  // 8) finally update w and ell
198  SGVector<float64_t>::add(m_w.vector, 1.0-gamma, m_w.vector, gamma, w_s.vector, m_w.vlen);
199  m_ell = (1.0-gamma) * m_ell + gamma * ell_s;
200 
201  } // end pi
202 
203  if (m_verbose)
204  m_helper->terminate();
205 
206  SG_DEBUG("Leaving CFWSOSVM::train_machine.\n");
207  return true;
208 }
209 
211 {
212  return m_lambda;
213 }
214 
216 {
217  m_lambda = lbda;
218 }
219 
220 int32_t CFWSOSVM::get_num_iter() const
221 {
222  return m_num_iter;
223 }
224 
225 void CFWSOSVM::set_num_iter(int32_t num_iter)
226 {
227  m_num_iter = num_iter;
228 }
229 
231 {
232  return m_gap_threshold;
233 }
234 
236 {
237  m_gap_threshold = gap_threshold;
238 }
239 
241 {
242  return m_ell;
243 }
244 
246 {
247  m_ell = ell;
248 }
249 
SGVector< float64_t > psi_truth
EMachineType
Definition: Machine.h:33
Base class of the labels used in Structured Output (SO) problems.
void set_ell(float64_t ell)
Definition: FWSOSVM.cpp:245
CLabels * m_labels
Definition: Machine.h:361
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
static float64_t primal_objective(SGVector< float64_t > w, CStructuredModel *model, float64_t lbda)
Definition: SOSVMHelper.cpp:57
virtual int32_t get_dim() const =0
static bool fequals_abs(const T &a, const T &b, const float64_t eps)
Definition: Math.h:318
void scale(T alpha)
Scale vector inplace.
Definition: SGVector.cpp:843
#define SG_REF(x)
Definition: SGObject.h:51
float64_t get_lambda() const
Definition: FWSOSVM.cpp:210
int32_t get_num_iter() const
Definition: FWSOSVM.cpp:220
virtual bool train_machine(CFeatures *data=NULL)
Definition: FWSOSVM.cpp:67
index_t vlen
Definition: SGVector.h:494
#define SG_SPRINT(...)
Definition: SGIO.h:180
void set_lambda(float64_t lbda)
Definition: FWSOSVM.cpp:215
#define ASSERT(x)
Definition: SGIO.h:201
float64_t get_gap_threshold() const
Definition: FWSOSVM.cpp:230
void add_to_dense(T alpha, T *vec, int32_t dim, bool abs_val=false)
double float64_t
Definition: common.h:50
static float64_t dual_objective(SGVector< float64_t > w, float64_t aloss, float64_t lbda)
Definition: SOSVMHelper.cpp:83
class CSOSVMHelper contains helper functions to compute primal objectives, dual objectives, average training losses, duality gaps etc. These values will be recorded to check convergence. This class is inspired by the matlab implementation of the block coordinate Frank-Wolfe SOSVM solver [1].
Definition: SOSVMHelper.h:31
void set_gap_threshold(float64_t gap_threshold)
Definition: FWSOSVM.cpp:235
float64_t get_ell() const
Definition: FWSOSVM.cpp:240
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute dot product between v1 and v2 (blas optimized)
Definition: Math.h:627
virtual bool check_training_setup() const
static float64_t average_loss(SGVector< float64_t > w, CStructuredModel *model, bool is_ub=false)
Definition: SOSVMHelper.cpp:88
Class CStructuredModel that represents the application specific model and contains most of the applic...
#define SG_UNREF(x)
Definition: SGObject.h:52
#define SG_DEBUG(...)
Definition: SGIO.h:107
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
virtual const char * get_name() const
Definition: FWSOSVM.h:46
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)=0
virtual void add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error, float64_t dual=-1, float64_t dgap=-1)
virtual int32_t get_num_labels() const
The class Features is the base class of all feature objects.
Definition: Features.h:68
SGVector< T > clone() const
Definition: SGVector.cpp:209
void set_num_iter(int32_t num_iter)
Definition: FWSOSVM.cpp:225
SGVector< float64_t > psi_pred
SGSparseVector< float64_t > psi_truth_sparse
static CStructuredLabels * to_structured(CLabels *base_labels)
#define SG_ADD(...)
Definition: SGObject.h:81
virtual const char * get_name() const
virtual EMachineType get_classifier_type()
Definition: FWSOSVM.cpp:62
SGSparseVector< float64_t > psi_pred_sparse
void add(const SGVector< T > x)
Definition: SGVector.cpp:281

SHOGUN Machine Learning Toolbox - Documentation