SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
tron.cpp
Go to the documentation of this file.
1 #include <math.h>
2 #include <stdio.h>
3 #include <string.h>
4 #include <stdarg.h>
5 
6 #include <shogun/lib/config.h>
7 #include <shogun/lib/Signal.h>
8 #include <shogun/lib/Time.h>
9 
13 
14 using namespace shogun;
15 
/**
 * Dot product of two strided vectors.
 *
 * With HAVE_LAPACK defined the call is forwarded to cblas_ddot; otherwise
 * the products X[incX*k]*Y[incY*k] for k = 0..N-1 are summed in index order.
 *
 * @param N     number of element pairs to combine
 * @param X     first vector
 * @param incX  stride between consecutive elements of X
 * @param Y     second vector
 * @param incY  stride between consecutive elements of Y
 * @return the accumulated dot product (0.0 in the fallback when N <= 0)
 */
double tron_ddot(const int N, const double *X, const int incX, const double *Y, const int incY)
{
#ifdef HAVE_LAPACK
	return cblas_ddot(N, X, incX, Y, incY);
#else
	double acc = 0.0;
	int32_t ix = 0;
	int32_t iy = 0;
	for (int32_t k = 0; k < N; ++k)
	{
		acc += X[ix]*Y[iy];
		ix += incX;
		iy += incY;
	}
	return acc;
#endif
}
27 
/**
 * Euclidean norm of a strided vector.
 *
 * With HAVE_LAPACK defined the call is forwarded to cblas_dnrm2; otherwise
 * the squares of X[incX*k], k = 0..N-1, are summed in index order and the
 * square root of that sum is returned.
 *
 * @param N     number of elements to include
 * @param X     input vector
 * @param incX  stride between consecutive elements of X
 * @return the 2-norm of the selected elements (0.0 in the fallback when N <= 0)
 */
double tron_dnrm2(const int N, const double *X, const int incX)
{
#ifdef HAVE_LAPACK
	return cblas_dnrm2(N, X, incX);
#else
	double sum_sq = 0.0;
	int32_t ix = 0;
	for (int32_t k = 0; k < N; ++k)
	{
		sum_sq += X[ix]*X[ix];
		ix += incX;
	}
	return sqrt(sum_sq);
#endif
}
39 
/**
 * Scale a strided vector in place: X[incX*i] *= alpha for i = 0..N-1.
 *
 * With HAVE_LAPACK defined the call is forwarded to cblas_dscal. The
 * fallback loop now honours incX as well (it previously scaled the first N
 * contiguous elements regardless of the stride), so both build
 * configurations touch the same elements for positive strides.
 *
 * @param N      number of elements to scale
 * @param alpha  scaling factor
 * @param X      vector that is modified in place
 * @param incX   stride between consecutive elements of X
 */
void tron_dscal(const int N, const double alpha, double *X, const int incX)
{
#ifdef HAVE_LAPACK
	cblas_dscal(N, alpha, X, incX);
#else
	for (int32_t i=0; i<N; i++)
		X[incX*i] *= alpha;
#endif
}
49 
/**
 * Strided update Y[incY*i] += alpha * X[incX*i] for i = 0..N-1.
 *
 * With HAVE_LAPACK defined the call is forwarded to cblas_daxpy. The
 * fallback loop now honours incX and incY as well (it previously ignored
 * both and always walked contiguous elements), so both build
 * configurations touch the same elements for positive strides.
 *
 * @param N      number of elements to update
 * @param alpha  multiplier applied to the elements of X
 * @param X      vector that is read
 * @param incX   stride between consecutive elements of X
 * @param Y      vector that is updated in place
 * @param incY   stride between consecutive elements of Y
 */
void tron_daxpy(const int N, const double alpha, const double *X, const int incX, double *Y, const int incY)
{
#ifdef HAVE_LAPACK
	cblas_daxpy(N,alpha,X,incX,Y,incY);
#else
	for (int32_t i=0; i<N; i++)
		Y[incY*i] += alpha*X[incX*i];
#endif
}
59 
60 CTron::CTron(const function *f, float64_t e, int32_t it)
61 : CSGObject()
62 {
63  this->fun_obj=const_cast<function *>(f);
64  this->eps=e;
65  this->max_iter=it;
66 }
67 
69 {
70 }
71 
72 void CTron::tron(float64_t *w, float64_t max_train_time)
73 {
74  // Parameters for updating the iterates.
75  float64_t eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
76 
77  // Parameters for updating the trust region size delta.
78  float64_t sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4.;
79 
80  int32_t i, cg_iter;
81  float64_t delta, snorm, one=1.0;
82  float64_t alpha, f, fnew, prered, actred, gs;
83 
84  /* calling external lib */
85  int n = (int) fun_obj->get_nr_variable();
86  int search = 1, iter = 1, inc = 1;
87  double *s = SG_MALLOC(double, n);
88  double *r = SG_MALLOC(double, n);
89  double *w_new = SG_MALLOC(double, n);
90  double *g = SG_MALLOC(double, n);
91 
92  for (i=0; i<n; i++)
93  w[i] = 0;
94 
95  f = fun_obj->fun(w);
96  fun_obj->grad(w, g);
97  delta = tron_dnrm2(n, g, inc);
98  float64_t gnorm1 = delta;
99  float64_t gnorm = gnorm1;
100 
101  if (gnorm <= eps*gnorm1)
102  search = 0;
103 
104  iter = 1;
105 
106  CSignal::clear_cancel();
107  CTime start_time;
108 
109  while (iter <= max_iter && search && (!CSignal::cancel_computations()))
110  {
111  if (max_train_time > 0 && start_time.cur_time_diff() > max_train_time)
112  break;
113 
114  cg_iter = trcg(delta, g, s, r);
115 
116  memcpy(w_new, w, sizeof(float64_t)*n);
117  tron_daxpy(n, one, s, inc, w_new, inc);
118 
119  gs = tron_ddot(n, g, inc, s, inc);
120  prered = -0.5*(gs-tron_ddot(n, s, inc, r, inc));
121  fnew = fun_obj->fun(w_new);
122 
123  // Compute the actual reduction.
124  actred = f - fnew;
125 
126  // On the first iteration, adjust the initial step bound.
127  snorm = tron_dnrm2(n, s, inc);
128  if (iter == 1)
129  delta = CMath::min(delta, snorm);
130 
131  // Compute prediction alpha*snorm of the step.
132  if (fnew - f - gs <= 0)
133  alpha = sigma3;
134  else
135  alpha = CMath::max(sigma1, -0.5*(gs/(fnew - f - gs)));
136 
137  // Update the trust region bound according to the ratio of actual to predicted reduction.
138  if (actred < eta0*prered)
139  delta = CMath::min(CMath::max(alpha, sigma1)*snorm, sigma2*delta);
140  else if (actred < eta1*prered)
141  delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma2*delta));
142  else if (actred < eta2*prered)
143  delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma3*delta));
144  else
145  delta = CMath::max(delta, CMath::min(alpha*snorm, sigma3*delta));
146 
147  SG_INFO("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter)
148 
149  if (actred > eta0*prered)
150  {
151  iter++;
152  memcpy(w, w_new, sizeof(float64_t)*n);
153  f = fnew;
154  fun_obj->grad(w, g);
155 
156  gnorm = tron_dnrm2(n, g, inc);
157  if (gnorm < eps*gnorm1)
158  break;
159  SG_SABS_PROGRESS(gnorm, -CMath::log10(gnorm), -CMath::log10(1), -CMath::log10(eps*gnorm1), 6)
160  }
161  if (f < -1.0e+32)
162  {
163  SG_WARNING("f < -1.0e+32\n")
164  break;
165  }
166  if (CMath::abs(actred) <= 0 && CMath::abs(prered) <= 0)
167  {
168  SG_WARNING("actred and prered <= 0\n")
169  break;
170  }
171  if (CMath::abs(actred) <= 1.0e-12*CMath::abs(f) &&
172  CMath::abs(prered) <= 1.0e-12*CMath::abs(f))
173  {
174  SG_WARNING("actred and prered too small\n")
175  break;
176  }
177  }
178 
179  SG_DONE()
180 
181  SG_FREE(g);
182  SG_FREE(r);
183  SG_FREE(w_new);
184  SG_FREE(s);
185 }
186 
187 int32_t CTron::trcg(float64_t delta, double* g, double* s, double* r)
188 {
189  /* calling external lib */
190  int i, cg_iter;
191  int n = (int) fun_obj->get_nr_variable();
192  int inc = 1;
193  double one = 1;
194  double *Hd = SG_MALLOC(double, n);
195  double *d = SG_MALLOC(double, n);
196  double rTr, rnewTrnew, alpha, beta, cgtol;
197 
198  for (i=0; i<n; i++)
199  {
200  s[i] = 0;
201  r[i] = -g[i];
202  d[i] = r[i];
203  }
204  cgtol = 0.1* tron_dnrm2(n, g, inc);
205 
206  cg_iter = 0;
207  rTr = tron_ddot(n, r, inc, r, inc);
208  while (1)
209  {
210  if (tron_dnrm2(n, r, inc) <= cgtol)
211  break;
212  cg_iter++;
213  fun_obj->Hv(d, Hd);
214 
215  alpha = rTr/tron_ddot(n, d, inc, Hd, inc);
216  tron_daxpy(n, alpha, d, inc, s, inc);
217  if (tron_dnrm2(n, s, inc) > delta)
218  {
219  SG_INFO("cg reaches trust region boundary\n")
220  alpha = -alpha;
221  tron_daxpy(n, alpha, d, inc, s, inc);
222 
223  double std = tron_ddot(n, s, inc, d, inc);
224  double sts = tron_ddot(n, s, inc, s, inc);
225  double dtd = tron_ddot(n, d, inc, d, inc);
226  double dsq = delta*delta;
227  double rad = sqrt(std*std + dtd*(dsq-sts));
228  if (std >= 0)
229  alpha = (dsq - sts)/(std + rad);
230  else
231  alpha = (rad - std)/dtd;
232  tron_daxpy(n, alpha, d, inc, s, inc);
233  alpha = -alpha;
234  tron_daxpy(n, alpha, Hd, inc, r, inc);
235  break;
236  }
237  alpha = -alpha;
238  tron_daxpy(n, alpha, Hd, inc, r, inc);
239  rnewTrnew = tron_ddot(n, r, inc, r, inc);
240  beta = rnewTrnew/rTr;
241  tron_dscal(n, beta, d, inc);
242  tron_daxpy(n, one, r, inc, d, inc);
243  rTr = rnewTrnew;
244  }
245 
246  SG_FREE(d);
247  SG_FREE(Hd);
248 
249  return(cg_iter);
250 }
251 
252 float64_t CTron::norm_inf(int32_t n, float64_t *x)
253 {
254  float64_t dmax = CMath::abs(x[0]);
255  for (int32_t i=1; i<n; i++)
256  if (CMath::abs(x[i]) >= dmax)
257  dmax = CMath::abs(x[i]);
258  return(dmax);
259 }
Class Time implements a stopwatch based on either CPU time or wall-clock time.
Definition: Time.h:47
#define SG_INFO(...)
Definition: SGIO.h:118
#define SG_DONE()
Definition: SGIO.h:157
virtual ~CTron()
Definition: tron.cpp:68
Vector::Scalar dot(Vector a, Vector b)
Definition: Redux.h:58
void tron(float64_t *w, float64_t max_train_time)
Definition: tron.cpp:72
Definition: basetag.h:132
double tron_ddot(const int N, const double *X, const int incX, const double *Y, const int incY)
Definition: tron.cpp:16
class Tron
Definition: tron.h:55
float64_t cur_time_diff(bool verbose=false)
Definition: Time.cpp:68
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
void tron_daxpy(const int N, const double alpha, const double *X, const int incX, double *Y, const int incY)
Definition: tron.cpp:50
void tron_dscal(const int N, const double alpha, double *X, const int incX)
Definition: tron.cpp:40
double float64_t
Definition: common.h:50
CTron()
Definition: tron.h:58
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
Matrix::Scalar max(Matrix m)
Definition: Redux.h:68
#define SG_WARNING(...)
Definition: SGIO.h:128
double tron_dnrm2(const int N, const double *X, const int incX)
Definition: tron.cpp:28
#define SG_SABS_PROGRESS(...)
Definition: SGIO.h:188

SHOGUN Machine Learning Toolbox - Documentation