SHOGUN  4.1.0
StochasticSOSVM.cpp
Go to the documentation of this file.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Shell Hu
 * Copyright (C) 2013 Shell Hu
 */

#include <shogun/structure/StochasticSOSVM.h>
#include <shogun/mathematics/Math.h>
#include <shogun/labels/LabelsFactory.h>
#include <shogun/lib/SGVector.h>

using namespace shogun;

CStochasticSOSVM::CStochasticSOSVM()
: CLinearStructuredOutputMachine()
{
	init();
}

CStochasticSOSVM::CStochasticSOSVM(
	CStructuredModel* model,
	CStructuredLabels* labs,
	bool do_weighted_averaging,
	bool verbose)
: CLinearStructuredOutputMachine(model, labs)
{
	REQUIRE(model != NULL && labs != NULL,
		"%s::CStochasticSOSVM(): model and labels cannot be NULL!\n", get_name());

	REQUIRE(labs->get_num_labels() > 0,
		"%s::CStochasticSOSVM(): number of labels should be greater than 0!\n", get_name());

	init();
	m_lambda = 1.0 / labs->get_num_labels();
	m_do_weighted_averaging = do_weighted_averaging;
	m_verbose = verbose;
}

void CStochasticSOSVM::init()
{
	SG_ADD(&m_lambda, "lambda", "Regularization constant", MS_NOT_AVAILABLE);
	SG_ADD(&m_num_iter, "num_iter", "Number of iterations", MS_NOT_AVAILABLE);
	SG_ADD(&m_do_weighted_averaging, "do_weighted_averaging", "Do weighted averaging", MS_NOT_AVAILABLE);
	SG_ADD(&m_debug_multiplier, "debug_multiplier", "Debug multiplier", MS_NOT_AVAILABLE);
	SG_ADD(&m_rand_seed, "rand_seed", "Random seed", MS_NOT_AVAILABLE);

	m_lambda = 1.0;
	m_num_iter = 50;
	m_do_weighted_averaging = true;
	m_debug_multiplier = 0;
	m_rand_seed = 1;
}

CStochasticSOSVM::~CStochasticSOSVM()
{
}

EMachineType CStochasticSOSVM::get_classifier_type()
{
	return CT_STOCHASTICSOSVM;
}

bool CStochasticSOSVM::train_machine(CFeatures* data)
{
	SG_DEBUG("Entering CStochasticSOSVM::train_machine.\n");
	if (data)
		set_features(data);

	// Initialize the model for training
	m_model->init_training();
	// Check that the scenario is correct to start with training
	ASSERT(m_model->check_training_setup());
	SG_DEBUG("The training setup is correct.\n");

	// Dimensionality of the joint feature space
	int32_t M = m_model->get_dim();
	// Number of training examples
	int32_t N = CLabelsFactory::to_structured(m_labels)->get_num_labels();

	SG_DEBUG("M=%d, N=%d.\n", M, N);

	// Initialize the weight vector
	m_w = SGVector<float64_t>(M);
	m_w.zero();

	SGVector<float64_t> w_avg;
	if (m_do_weighted_averaging)
		w_avg = m_w.clone();

	// logging
	if (m_verbose)
	{
		if (m_helper != NULL)
			SG_UNREF(m_helper);

		m_helper = new CSOSVMHelper();
		SG_REF(m_helper);
	}

	int32_t debug_iter = 1;
	if (m_debug_multiplier == 0)
	{
		debug_iter = N;
		m_debug_multiplier = 100;
	}

	CMath::init_random(m_rand_seed);

	// Main loop
	int32_t k = 0;
	for (int32_t pi = 0; pi < m_num_iter; ++pi)
	{
		for (int32_t si = 0; si < N; ++si)
		{
			// 1) pick a random example
			int32_t i = CMath::random(0, N-1);

			// 2) solve the loss-augmented inference for point i
			CResultSet* result = m_model->argmax(m_w, i);

			// 3) get the subgradient
			// psi_i(y) := phi(x_i,y_i) - phi(x_i, y)
			SGVector<float64_t> psi_i(M);
			SGVector<float64_t> w_s(M);

			if (result->psi_computed)
			{
				SGVector<float64_t>::add(psi_i.vector,
					1.0, result->psi_truth.vector, -1.0, result->psi_pred.vector,
					psi_i.vlen);
			}
			else if (result->psi_computed_sparse)
			{
				psi_i.zero();
				result->psi_pred_sparse.add_to_dense(1.0, psi_i.vector, psi_i.vlen);
				result->psi_truth_sparse.add_to_dense(-1.0, psi_i.vector, psi_i.vlen);
			}
			else
			{
				SG_ERROR("model(%s) should have either psi_computed or psi_computed_sparse "
					"set to true\n", m_model->get_name());
			}

			w_s = psi_i.clone();
			w_s.scale(1.0 / (N*m_lambda));

			// 4) step-size gamma
			float64_t gamma = 1.0 / (k+1.0);

			// 5) finally update the weights
			SGVector<float64_t>::add(m_w.vector,
				1.0-gamma, m_w.vector, gamma*N, w_s.vector, m_w.vlen);

			// 6) Optionally, update the weighted average
			if (m_do_weighted_averaging)
			{
				float64_t rho = 2.0 / (k+2.0);
				SGVector<float64_t>::add(w_avg.vector,
					1.0-rho, w_avg.vector, rho, m_w.vector, w_avg.vlen);
			}

			k += 1;
			SG_UNREF(result);

			// Debug: compute objective and training error
			if (m_verbose && k == debug_iter)
			{
				SGVector<float64_t> w_debug;
				if (m_do_weighted_averaging)
					w_debug = w_avg.clone();
				else
					w_debug = m_w.clone();

				float64_t primal = CSOSVMHelper::primal_objective(w_debug, m_model, m_lambda);
				float64_t train_error = CSOSVMHelper::average_loss(w_debug, m_model);

				SG_DEBUG("pass %d (iteration %d), SVM primal = %f, train_error = %f \n",
					pi, k, primal, train_error);

				m_helper->add_debug_info(primal, (1.0*k) / N, train_error);

				debug_iter = CMath::min(debug_iter+N, debug_iter*(1+m_debug_multiplier/100));
			}
		}
	}

	if (m_do_weighted_averaging)
		m_w = w_avg.clone();

	if (m_verbose)
		m_helper->terminate();

	SG_DEBUG("Leaving CStochasticSOSVM::train_machine.\n");
	return true;
}
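// Summary of one inner iteration above (k is the 0-based iteration counter):
//
//   w_s   = psi_i / (lambda * N)                      scaled subgradient
//   gamma = 1 / (k + 1)                               step size
//   w    <- (1 - gamma) * w + gamma * N * w_s
//         = (1 - gamma) * w + psi_i / (lambda * (k + 1))
//
// With weighted averaging enabled, rho = 2 / (k + 2) makes w_avg a running
// average of the iterates w_0, ..., w_k in which iterate w_t is weighted
// proportionally to (t + 1), so later iterates count more; w_avg replaces
// m_w once training finishes.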

float64_t CStochasticSOSVM::get_lambda() const
{
	return m_lambda;
}

void CStochasticSOSVM::set_lambda(float64_t lbda)
{
	m_lambda = lbda;
}

int32_t CStochasticSOSVM::get_num_iter() const
{
	return m_num_iter;
}

void CStochasticSOSVM::set_num_iter(int32_t num_iter)
{
	m_num_iter = num_iter;
}

int32_t CStochasticSOSVM::get_debug_multiplier() const
{
	return m_debug_multiplier;
}

void CStochasticSOSVM::set_debug_multiplier(int32_t multiplier)
{
	m_debug_multiplier = multiplier;
}

uint32_t CStochasticSOSVM::get_rand_seed() const
{
	return m_rand_seed;
}

void CStochasticSOSVM::set_rand_seed(uint32_t rand_seed)
{
	m_rand_seed = rand_seed;
}

SHOGUN Machine Learning Toolbox - Project Documentation