SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
slep_mc_tree_lr.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Sergey Lisitsyn
8  * Copyright (C) 2010-2012 Jun Liu, Jieping Ye
9  */
10 
12 #ifdef HAVE_EIGEN3
18 #include <shogun/lib/Signal.h>
19 #include <shogun/lib/Time.h>
20 #include <iostream>
21 
22 using namespace shogun;
23 using namespace Eigen;
24 using namespace std;
25 
26 namespace shogun
27 {
28 
29 slep_result_t slep_mc_tree_lr(
30  CDotFeatures* features,
31  CMulticlassLabels* labels,
32  float64_t z,
33  const slep_options& options)
34 {
35  int i,j;
36  // obtain problem parameters
37  int n_feats = features->get_dim_feature_space();
38  int n_vecs = features->get_num_vectors();
39  int n_classes = labels->get_num_classes();
40 
41  // labels vector containing values in range (0 .. n_classes)
42  SGVector<float64_t> labels_vector = labels->get_labels();
43 
44  // initialize matrices and vectors to be used
45  // weight vector
46  MatrixXd w = MatrixXd::Zero(n_feats, n_classes);
47  // intercepts (biases)
48  VectorXd c = VectorXd::Zero(n_classes);
49 
50  if (options.last_result)
51  {
52  SGMatrix<float64_t> last_w = options.last_result->w;
53  SGVector<float64_t> last_c = options.last_result->c;
54  for (i=0; i<n_classes; i++)
55  {
56  c[i] = last_c[i];
57  for (j=0; j<n_feats; j++)
58  w(j,i) = last_w(j,i);
59  }
60  }
61  // iterative process matrices and vectors
62  MatrixXd wp = w, wwp = MatrixXd::Zero(n_feats, n_classes);
63  VectorXd cp = c, ccp = VectorXd::Zero(n_classes);
64  // search point weight vector
65  MatrixXd search_w = MatrixXd::Zero(n_feats, n_classes);
66  // search point intercepts
67  VectorXd search_c = VectorXd::Zero(n_classes);
68  // dot products
69  MatrixXd Aw = MatrixXd::Zero(n_vecs, n_classes);
70  for (j=0; j<n_classes; j++)
71  features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
72  MatrixXd As = MatrixXd::Zero(n_vecs, n_classes);
73  MatrixXd Awp = MatrixXd::Zero(n_vecs, n_classes);
74  // gradients
75  MatrixXd g = MatrixXd::Zero(n_feats, n_classes);
76  VectorXd gc = VectorXd::Zero(n_classes);
77  // projection
78  MatrixXd v = MatrixXd::Zero(n_feats, n_classes);
79 
80  // Lipschitz continuous gradient parameter for line search
81  double L = 1.0/(n_vecs*n_classes);
82  // coefficients for search point computation
83  double alphap = 0, alpha = 1;
84 
85  // lambda regularization parameter
86  double lambda = z;
87  // objective values
88  double objective = 0.0;
89  double objective_p = 0.0;
90 
91  int iter = 0;
92  bool done = false;
93  CTime time;
94  //internal::set_is_malloc_allowed(false);
95  while ((!done) && (iter<options.max_iter) && (!CSignal::cancel_computations()))
96  {
97  double beta = (alphap-1)/alpha;
98  // compute search points
99  search_w = w + beta*wwp;
100  search_c = c + beta*ccp;
101 
102  // update dot products with search point
103  As = Aw + beta*(Aw-Awp);
104 
105  // compute objective and gradient at search point
106  double fun_s = 0;
107  g.setZero();
108  gc.setZero();
109  // for each vector
110  for (i=0; i<n_vecs; i++)
111  {
112  // class of current vector
113  int vec_class = labels_vector[i];
114  // for each class
115  for (j=0; j<n_classes; j++)
116  {
117  // compute logistic loss
118  double aa = ((vec_class == j) ? -1.0 : 1.0)*(As(i,j) + search_c(j));
119  double bb = aa > 0.0 ? aa : 0.0;
120  // avoid underflow via log-sum-exp trick
121  fun_s += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
122  double prob = 1.0/(1+CMath::exp(aa));
123  double b = ((vec_class == j) ? -1.0 : 1.0)*(1-prob);
124  // update gradient of intercepts
125  gc[j] += b;
126  // update gradient of weight vectors
127  features->add_to_dense_vec(b, i, g.col(j).data(), n_feats);
128  }
129  }
130  //fun_s /= (n_vecs*n_classes);
131 
132  wp = w;
133  Awp = Aw;
134  cp = c;
135 
136  int inner_iter = 0;
137  double fun_x = 0;
138 
139  // line search process
140  while (inner_iter<5000)
141  {
142  // compute line search point
143  v = search_w - g/L;
144  c = search_c - gc/L;
145 
146  // compute projection of gradient
147  if (options.general)
148  general_altra_mt(w.data(),v.data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes,lambda/L);
149  else
150  altra_mt(w.data(),v.data(),n_classes,n_feats,options.ind_t,options.n_nodes,lambda/L);
151  v = w - search_w;
152 
153  // update dot products
154  for (j=0; j<n_classes; j++)
155  features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
156 
157  // compute objective at search point
158  fun_x = 0;
159  for (i=0; i<n_vecs; i++)
160  {
161  int vec_class = labels_vector[i];
162  for (j=0; j<n_classes; j++)
163  {
164  double aa = ((vec_class == j) ? -1.0 : 1.0)*(Aw(i,j) + c(j));
165  double bb = aa > 0.0 ? aa : 0.0;
166  fun_x += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
167  }
168  }
169  //fun_x /= (n_vecs*n_classes);
170 
171  // check for termination of line search
172  double r_sum = (v.squaredNorm() + (c-search_c).squaredNorm())/2;
173  double l_sum = fun_x - fun_s - v.cwiseProduct(g).sum() - (c-search_c).dot(gc);
174 
175  // stop if projected gradient is less than 1e-20
176  if (r_sum <= 1e-20)
177  {
178  SG_SINFO("Gradient step makes little improvement (%f)\n",r_sum)
179  done = true;
180  break;
181  }
182 
183  if (l_sum <= r_sum*L)
184  break;
185  else
186  L = CMath::max(2*L, l_sum/r_sum);
187 
188  inner_iter++;
189  }
190 
191  // update alpha coefficients
192  alphap = alpha;
193  alpha = (1+CMath::sqrt(4*alpha*alpha+1))/2;
194 
195  // update wwp and ccp
196  wwp = w - wp;
197  ccp = c - cp;
198 
199  // update objectives
200  objective_p = objective;
201  objective = fun_x;
202 
203  // compute tree norm
204  double tree_norm = 0.0;
205  if (options.general)
206  {
207  for (i=0; i<n_classes; i++)
208  tree_norm += general_treeNorm(w.col(i).data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes);
209  }
210  else
211  {
212  for (i=0; i<n_classes; i++)
213  tree_norm += treeNorm(w.col(i).data(),n_classes,n_feats,options.ind_t,options.n_nodes);
214  }
215 
216  // regularize objective with tree norm
217  objective += lambda*tree_norm;
218 
219  //cout << "Objective = " << objective << endl;
220 
221  // check for termination of whole process
222  if ((CMath::abs(objective - objective_p) < options.tolerance*CMath::abs(objective_p)) && (iter>2))
223  {
224  SG_SINFO("Objective changes less than tolerance\n")
225  done = true;
226  }
227 
228  iter++;
229  }
230  SG_SINFO("%d iterations passed, objective = %f\n",iter,objective)
231  //internal::set_is_malloc_allowed(true);
232 
233  // output computed weight vectors and intercepts
234  SGMatrix<float64_t> r_w(n_feats,n_classes);
235  for (j=0; j<n_classes; j++)
236  {
237  for (i=0; i<n_feats; i++)
238  r_w(i,j) = w(i,j);
239  }
240  //r_w.display_matrix();
241  SGVector<float64_t> r_c(n_classes);
242  for (j=0; j<n_classes; j++)
243  r_c[j] = c[j];
244  return slep_result_t(r_w, r_c);
245 };
246 };
247 #endif
Class Time that implements a stopwatch based on either cpu time or wall clock time.
Definition: Time.h:47
slep_result_t slep_mc_tree_lr(CDotFeatures *features, CMulticlassLabels *labels, float64_t z, const slep_options &options)
double treeNorm(double *x, int ldx, int n, double *ind, int nodes)
Definition: altra.cpp:143
virtual void dense_dot_range(float64_t *output, int32_t start, int32_t stop, float64_t *alphas, float64_t *vec, int32_t dim, float64_t b)
Definition: DotFeatures.cpp:67
Vector::Scalar dot(Vector a, Vector b)
Definition: Redux.h:56
Definition: SGMatrix.h:20
virtual int32_t get_num_vectors() const =0
virtual void add_to_dense_vec(float64_t alpha, int32_t vec_idx1, float64_t *vec2, int32_t vec2_len, bool abs_val=false)=0
Features that support dot products among other operations.
Definition: DotFeatures.h:44
virtual int32_t get_dim_feature_space() const =0
SGVector< float64_t > get_labels()
Definition: DenseLabels.cpp:82
Multiclass Labels for multi-class classification.
double general_treeNorm(double *x, int ldx, int n, double *G, double *ind, int nodes)
void general_altra_mt(double *X, double *V, int n, int k, double *G, double *ind, int nodes, double mult)
double float64_t
Definition: common.h:50
static T max(T a, T b)
Definition: Math.h:168
static bool cancel_computations()
Definition: Signal.h:86
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
static float64_t exp(float64_t x)
Definition: Math.h:621
#define SG_SINFO(...)
Definition: SGIO.h:173
static float64_t log(float64_t v)
Definition: Math.h:922
static float32_t sqrt(float32_t x)
Definition: Math.h:459
void altra_mt(double *X, double *V, int n, int k, double *ind, int nodes, double mult)
Definition: altra.cpp:92
static T abs(T a)
Definition: Math.h:179

SHOGUN 机器学习工具包 - 项目文档