15 using namespace shogun;
20 this->fun_obj=
const_cast<function *
>(f);
32 float64_t eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
35 float64_t sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4.;
39 float64_t alpha, f, fnew, prered, actred, gs;
42 int n = (int) fun_obj->get_nr_variable();
43 int search = 1, iter = 1, inc = 1;
54 delta = cblas_dnrm2(n, g, inc);
58 if (gnorm <= eps*gnorm1)
63 CSignal::clear_cancel();
66 while (iter <= max_iter && search && (!CSignal::cancel_computations()))
68 if (max_train_time > 0 && start_time.
cur_time_diff() > max_train_time)
71 cg_iter = trcg(delta, g, s, r);
74 cblas_daxpy(n, one, s, inc, w_new, inc);
76 gs = cblas_ddot(n, g, inc, s, inc);
77 prered = -0.5*(gs-cblas_ddot(n, s, inc, r, inc));
78 fnew = fun_obj->fun(w_new);
84 snorm = cblas_dnrm2(n, s, inc);
86 delta = CMath::min(delta, snorm);
89 if (fnew - f - gs <= 0)
92 alpha = CMath::max(sigma1, -0.5*(gs/(fnew - f - gs)));
95 if (actred < eta0*prered)
96 delta = CMath::min(CMath::max(alpha, sigma1)*snorm, sigma2*delta);
97 else if (actred < eta1*prered)
98 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma2*delta));
99 else if (actred < eta2*prered)
100 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma3*delta));
102 delta = CMath::max(delta, CMath::min(alpha*snorm, sigma3*delta));
104 SG_INFO(
"iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
106 if (actred > eta0*prered)
113 gnorm = cblas_dnrm2(n, g, inc);
114 if (gnorm < eps*gnorm1)
116 SG_SABS_PROGRESS(gnorm, -CMath::log10(gnorm), -CMath::log10(1), -CMath::log10(eps*gnorm1), 6);
123 if (CMath::abs(actred) <= 0 && CMath::abs(prered) <= 0)
128 if (CMath::abs(actred) <= 1.0e-12*CMath::abs(f) &&
129 CMath::abs(prered) <= 1.0e-12*CMath::abs(f))
144 int32_t CTron::trcg(
float64_t delta,
double* g,
double* s,
double* r)
148 int n = (int) fun_obj->get_nr_variable();
153 double rTr, rnewTrnew, alpha, beta, cgtol;
161 cgtol = 0.1* cblas_dnrm2(n, g, inc);
164 rTr = cblas_ddot(n, r, inc, r, inc);
167 if (cblas_dnrm2(n, r, inc) <= cgtol)
172 alpha = rTr/cblas_ddot(n, d, inc, Hd, inc);
173 cblas_daxpy(n, alpha, d, inc, s, inc);
174 if (cblas_dnrm2(n, s, inc) > delta)
176 SG_INFO(
"cg reaches trust region boundary\n");
178 cblas_daxpy(n, alpha, d, inc, s, inc);
180 double std = cblas_ddot(n, s, inc, d, inc);
181 double sts = cblas_ddot(n, s, inc, s, inc);
182 double dtd = cblas_ddot(n, d, inc, d, inc);
183 double dsq = delta*
delta;
184 double rad = sqrt(std*std + dtd*(dsq-sts));
186 alpha = (dsq - sts)/(std + rad);
188 alpha = (rad - std)/dtd;
189 cblas_daxpy(n, alpha, d, inc, s, inc);
191 cblas_daxpy(n, alpha, Hd, inc, r, inc);
195 cblas_daxpy(n, alpha, Hd, inc, r, inc);
196 rnewTrnew = cblas_ddot(n, r, inc, r, inc);
197 beta = rnewTrnew/rTr;
198 cblas_dscal(n, beta, d, inc);
199 cblas_daxpy(n, one, r, inc, d, inc);
212 for (int32_t i=1; i<n; i++)
213 if (CMath::abs(x[i]) >= dmax)
214 dmax = CMath::abs(x[i]);