SHOGUN 4.2.0
slep_mc_tree_lr.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Sergey Lisitsyn
 * Copyright (C) 2010-2012 Jun Liu, Jieping Ye
 */

#include <shogun/lib/slep/slep_mc_tree_lr.h>
#ifdef USE_GPL_SHOGUN

#include <shogun/lib/slep/tree/altra.h>
#include <shogun/lib/slep/tree/general_altra.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/Math.h>
#include <shogun/lib/Signal.h>
#include <shogun/lib/Time.h>
#include <iostream>

using namespace shogun;
using namespace Eigen;
using namespace std;

namespace shogun
{

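/** Multiclass logistic regression with tree-structured sparsity
 * regularization, solved with the accelerated proximal gradient method of
 * the SLEP package: each iteration forms a search point by extrapolation,
 * evaluates the logistic loss and its gradient there, applies the proximal
 * operator of the tree norm (altra_mt / general_altra_mt), and backtracks
 * on the Lipschitz estimate L until a quadratic upper bound holds.
 *
 * @param features features to learn from
 * @param labels multiclass labels
 * @param z regularization parameter lambda
 * @param options SLEP solver options (tree structure, tolerance, max_iter, ...)
 * @return learned weight matrix and intercepts packed into a slep_result_t
 */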
slep_result_t slep_mc_tree_lr(
	CDotFeatures* features,
	CMulticlassLabels* labels,
	float64_t z,
	const slep_options& options)
{
	int i,j;
	// obtain problem parameters
	int n_feats = features->get_dim_feature_space();
	int n_vecs = features->get_num_vectors();
	int n_classes = labels->get_num_classes();

	// labels vector containing values in the range [0, n_classes)
	SGVector<float64_t> labels_vector = labels->get_labels();

	// initialize matrices and vectors to be used
	// weight vector
	MatrixXd w = MatrixXd::Zero(n_feats, n_classes);
	// intercepts (biases)
	VectorXd c = VectorXd::Zero(n_classes);

	// warm-start from a previous solution when one is provided
	if (options.last_result)
	{
		SGMatrix<float64_t> last_w = options.last_result->w;
		SGVector<float64_t> last_c = options.last_result->c;
		for (i=0; i<n_classes; i++)
		{
			c[i] = last_c[i];
			for (j=0; j<n_feats; j++)
				w(j,i) = last_w(j,i);
		}
	}
	// matrices and vectors of the iterative process
	MatrixXd wp = w, wwp = MatrixXd::Zero(n_feats, n_classes);
	VectorXd cp = c, ccp = VectorXd::Zero(n_classes);
	// search point weight vector
	MatrixXd search_w = MatrixXd::Zero(n_feats, n_classes);
	// search point intercepts
	VectorXd search_c = VectorXd::Zero(n_classes);
	// dot products of feature vectors with weights, Aw(i,j) = x_i . w_j
	MatrixXd Aw = MatrixXd::Zero(n_vecs, n_classes);
	for (j=0; j<n_classes; j++)
		features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
	MatrixXd As = MatrixXd::Zero(n_vecs, n_classes);
	MatrixXd Awp = MatrixXd::Zero(n_vecs, n_classes);
	// gradients
	MatrixXd g = MatrixXd::Zero(n_feats, n_classes);
	VectorXd gc = VectorXd::Zero(n_classes);
	// projection
	MatrixXd v = MatrixXd::Zero(n_feats, n_classes);

	// estimate of the Lipschitz constant of the gradient, adapted by the line search
	double L = 1.0/(n_vecs*n_classes);
	// coefficients for search point computation
	double alphap = 0, alpha = 1;

	// lambda regularization parameter
	double lambda = z;
	// objective values
	double objective = 0.0;
	double objective_p = 0.0;

	int iter = 0;
	bool done = false;
	CTime time;
	//internal::set_is_malloc_allowed(false);
	while ((!done) && (iter<options.max_iter) && (!CSignal::cancel_computations()))
	{
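		// FISTA-style extrapolation: beta combines the current iterate with
		// the previous step (wwp = w - wp, ccp = c - cp) to form the search point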
		double beta = (alphap-1)/alpha;
		// compute search points
		search_w = w + beta*wwp;
		search_c = c + beta*ccp;

		// update dot products with search point
		As = Aw + beta*(Aw-Awp);

		// compute objective and gradient at the search point
		double fun_s = 0;
		g.setZero();
		gc.setZero();
		// for each vector
		for (i=0; i<n_vecs; i++)
		{
			// class of current vector
			int vec_class = labels_vector[i];
			// for each class
			for (j=0; j<n_classes; j++)
			{
				// compute logistic loss
				double aa = ((vec_class == j) ? -1.0 : 1.0)*(As(i,j) + search_c(j));
				double bb = aa > 0.0 ? aa : 0.0;
				// avoid overflow via the log-sum-exp trick:
				// log(1+exp(aa)) = bb + log(exp(-bb) + exp(aa-bb))
				fun_s += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
				double prob = 1.0/(1+CMath::exp(aa));
				double b = ((vec_class == j) ? -1.0 : 1.0)*(1-prob);
				// update gradient of intercepts
				gc[j] += b;
				// update gradient of weight vectors
				features->add_to_dense_vec(b, i, g.col(j).data(), n_feats);
			}
		}
		//fun_s /= (n_vecs*n_classes);

		// remember the previous iterate and its dot products
		wp = w;
		Awp = Aw;
		cp = c;

		int inner_iter = 0;
		double fun_x = 0;

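		// Backtracking on L: each pass takes a gradient step of size 1/L from
		// the search point, applies the proximal operator of the tree norm,
		// and accepts the step once the quadratic upper bound test below holds.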
		// line search process
		while (inner_iter<5000)
		{
			// compute line search point
			v = search_w - g/L;
			c = search_c - gc/L;

			// apply the proximal operator of the tree norm (result is written into w)
			if (options.general)
				general_altra_mt(w.data(),v.data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes,lambda/L);
			else
				altra_mt(w.data(),v.data(),n_classes,n_feats,options.ind_t,options.n_nodes,lambda/L);
			// v now holds the step taken from the search point
			v = w - search_w;

			// update dot products
			for (j=0; j<n_classes; j++)
				features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);

			// compute objective at the new point
			fun_x = 0;
			for (i=0; i<n_vecs; i++)
			{
				int vec_class = labels_vector[i];
				for (j=0; j<n_classes; j++)
				{
					double aa = ((vec_class == j) ? -1.0 : 1.0)*(Aw(i,j) + c(j));
					double bb = aa > 0.0 ? aa : 0.0;
					fun_x += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
				}
			}
			//fun_x /= (n_vecs*n_classes);

			// check for termination of line search
			double r_sum = (v.squaredNorm() + (c-search_c).squaredNorm())/2;
			double l_sum = fun_x - fun_s - v.cwiseProduct(g).sum() - (c-search_c).dot(gc);
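			// (r_sum is half the squared norm of the proximal step; l_sum is
			// the deviation of the objective from its linear approximation at
			// the search point, so l_sum <= L*r_sum certifies the upper bound)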

			// stop if the squared norm of the proximal step is below 1e-20
			if (r_sum <= 1e-20)
			{
				SG_SINFO("Gradient step makes little improvement (%f)\n",r_sum)
				done = true;
				break;
			}


			// if the bound holds accept the step, otherwise increase L and retry
			if (l_sum <= r_sum*L)
				break;
			else
				L = CMath::max(2*L, l_sum/r_sum);

			inner_iter++;
		}

		// update the momentum coefficients
		alphap = alpha;
		alpha = (1+CMath::sqrt(4*alpha*alpha+1))/2;

		// update the differences used for extrapolation
		wwp = w - wp;
		ccp = c - cp;

		// update objectives
		objective_p = objective;
		objective = fun_x;

		// compute tree norm
		double tree_norm = 0.0;
		if (options.general)
		{
			for (i=0; i<n_classes; i++)
				tree_norm += general_treeNorm(w.col(i).data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes);
		}
		else
		{
			for (i=0; i<n_classes; i++)
				tree_norm += treeNorm(w.col(i).data(),n_classes,n_feats,options.ind_t,options.n_nodes);
		}

		// regularize objective with tree norm
		objective += lambda*tree_norm;

		//cout << "Objective = " << objective << endl;

		// check for termination of the whole process
		if ((CMath::abs(objective - objective_p) < options.tolerance*CMath::abs(objective_p)) && (iter>2))
		{
			SG_SINFO("Objective changes less than tolerance\n")
			done = true;
		}

		iter++;
	}
	SG_SINFO("%d iterations passed, objective = %f\n",iter,objective)
	//internal::set_is_malloc_allowed(true);

	// output computed weight vectors and intercepts
	SGMatrix<float64_t> r_w(n_feats,n_classes);
	for (j=0; j<n_classes; j++)
	{
		for (i=0; i<n_feats; i++)
			r_w(i,j) = w(i,j);
	}
	//r_w.display_matrix();
	SGVector<float64_t> r_c(n_classes);
	for (j=0; j<n_classes; j++)
		r_c[j] = c[j];
	return slep_result_t(r_w, r_c);
}
} // namespace shogun

#endif //USE_GPL_SHOGUN
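
A minimal calling sketch (not part of the file above): it assumes dense features, labels taking values in [0, n_classes), and a single tree node covering all features. The slep_options::default_options() call and the SLEP [start index, end index, weight] layout of ind_t are assumptions for illustration, not taken from this listing; within Shogun this solver is normally driven by CMulticlassTreeGuidedLogisticRegression, which fills these options from the block tree.

#include <shogun/lib/slep/slep_mc_tree_lr.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>

using namespace shogun;

slep_result_t train_tree_lr(CDenseFeatures<float64_t>* features,
                            CMulticlassLabels* labels)
{
	slep_options options = slep_options::default_options();
	// hypothetical one-node tree: SLEP encodes each node as a
	// [start index, end index, weight] triple (1-based, inclusive)
	double ind_t[] = {1, (double)features->get_dim_feature_space(), 1.0};
	options.ind_t = ind_t;
	options.n_nodes = 1;
	options.general = false;  // use altra_mt rather than general_altra_mt
	options.max_iter = 1000;
	options.tolerance = 1e-4; // relative change of the objective
	float64_t z = 0.1;        // regularization strength (lambda)
	return slep_mc_tree_lr(features, labels, z, options);
}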