// Dot product of two strided double vectors of length N:
// the sum over i < N of X[incX*i]*Y[incY*i].
// The first path delegates to CBLAS cblas_ddot; the loop below is the non-BLAS
// fallback. Not visible in this extract:
//   - the preprocessor guard choosing between the two paths;
//   - the braces;
//   - the declaration of the `dot` accumulator;
//   - the fallback's return.
16 double tron_ddot(
const int N,
const double *X,
const int incX,
const double *Y,
const int incY)
19 return cblas_ddot(N,X,incX,Y,incY);
// NOTE(review): the fallback indexes X[incX*i] directly. That only matches BLAS
// semantics for positive increments; BLAS starts from the far end when inc < 0.
// The calls visible in this extract pass a variable `inc`, which tron() sets to 1.
22 for (int32_t i=0; i<N; i++)
23 dot += X[incX*i]*Y[incY*i];
// Euclidean (2-)norm of a strided double vector of length N.
// The first path delegates to CBLAS cblas_dnrm2. The loop below is the non-BLAS
// fallback; it accumulates the sum of squares into `dot`. Not visible in this
// extract: the accumulator declaration, the guard, and the final return
// (presumably sqrt(dot)).
28 double tron_dnrm2(
const int N,
const double *X,
const int incX)
31 return cblas_dnrm2(N,X,incX);
// NOTE(review): squaring each entry directly can overflow or underflow for very
// large or very small magnitudes. Reference BLAS dnrm2 scales the running sum to
// avoid this.
34 for (int32_t i=0; i<N; i++)
35 dot += X[incX*i]*X[incX*i];
// Scales the N strided entries of X in place by alpha via CBLAS cblas_dscal.
// `return f(...)` where f returns void is valid C++ inside a void function.
// The loop header below belongs to the non-BLAS fallback. Its body, the guard
// selecting it, and the braces are not visible in this extract.
40 void tron_dscal(
const int N,
const double alpha,
double *X,
const int incX)
43 return cblas_dscal(N,alpha,X,incX);
45 for (int32_t i=0; i<N; i++)
// Computes Y := alpha*X + Y over N strided entries via CBLAS cblas_daxpy.
// The loop header below belongs to the non-BLAS fallback. Its body, the guard
// selecting it, and the braces are not visible in this extract.
50 void tron_daxpy(
const int N,
const double alpha,
const double *X,
const int incX,
double *Y,
const int incY)
53 cblas_daxpy(N,alpha,X,incX,Y,incY);
55 for (int32_t i=0; i<N; i++)
// Fragment of a CTron member whose signature is not visible (presumably the
// constructor). It stores the objective-function object used by tron()/trcg().
// NOTE(review): const_cast drops the const qualifier from f. Confirm nothing
// mutates the object through fun_obj if callers may pass a genuinely const object.
63 this->fun_obj=
const_cast<function *
>(f);
// Body fragment of CTron::tron; the signature is not visible in this extract.
// It is a trust-region Newton loop. Many original lines are missing here,
// including braces, the step-acceptance update and the frees of the buffers
// below. The comments describe only what is shown.
// Trust-region constants:
//   eta0/eta1/eta2 are thresholds on actual vs. predicted reduction;
//   sigma1/sigma2/sigma3 are the factors used to shrink or grow the radius delta.
75 float64_t eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
78 float64_t sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4.;
82 float64_t alpha, f, fnew, prered, actred, gs;
85 int n = (int) fun_obj->get_nr_variable();
// search: loop-continuation flag. inc = 1: unit stride passed to the tron_* wrappers.
86 int search = 1, iter = 1, inc = 1;
// Length-n work vectors.
// NOTE(review): the matching SG_FREE calls are not visible in this extract.
// Confirm every exit path, including the time-limit exit, releases them.
87 double *s = SG_MALLOC(
double, n);
88 double *r = SG_MALLOC(
double, n);
89 double *w_new = SG_MALLOC(
double, n);
90 double *g = SG_MALLOC(
double, n);
// Initial convergence test. gnorm and gnorm1 are computed on lines not shown;
// gnorm1 is presumably the initial gradient norm.
// NOTE(review): this test uses `<=` while the in-loop test below uses `<`.
101 if (gnorm <= eps*gnorm1)
106 CSignal::clear_cancel();
// Iterate until max_iter is reached, `search` is cleared, or the user cancels.
109 while (iter <= max_iter && search && (!CSignal::cancel_computations()))
// Optional wall-clock budget, enforced only when max_train_time > 0.
111 if (max_train_time > 0 && start_time.
cur_time_diff() > max_train_time)
// Approximately solve the trust-region subproblem with radius delta.
// s presumably receives the step and r the CG residual. The return value is
// logged below as "CG", presumably the number of CG iterations.
114 cg_iter = trcg(delta, g, s, r);
// Predicted reduction of the quadratic model along s.
// gs is set on a line not shown; it is presumably g.s.
120 prered = -0.5*(gs-
tron_ddot(n, s, inc, r, inc));
// Objective value at the trial point w_new.
121 fnew = fun_obj->fun(w_new);
// Clamp the radius to the step length snorm. The guard for this statement is on
// a line not shown (presumably it applies only on the first iteration).
129 delta = CMath::min(delta, snorm);
// Step-length factor alpha. When fnew - f - gs > 0, the assignment below is
// t* = -gs / (2*(fnew - f - gs)): the minimiser of the 1-D quadratic with value f
// and slope gs at t = 0 and value fnew at t = 1, clamped below by sigma1.
// The assignment taken when this condition holds is not visible here.
132 if (fnew - f - gs <= 0)
135 alpha =
CMath::max(sigma1, -0.5*(gs/(fnew - f - gs)));
// Radius update, chosen by comparing actred with prered:
//   actred < eta0*prered : shrink, to at most sigma2*delta;
//   actred < eta1*prered : shrink, to within [sigma1*delta, sigma2*delta];
//   actred < eta2*prered : stay within [sigma1*delta, sigma3*delta];
//   otherwise            : never shrink, grow by up to sigma3.
// The final `else` keyword is on a line not shown.
138 if (actred < eta0*prered)
139 delta = CMath::min(
CMath::max(alpha, sigma1)*snorm, sigma2*delta);
140 else if (actred < eta1*prered)
141 delta =
CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma2*delta));
142 else if (actred < eta2*prered)
143 delta =
CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma3*delta));
145 delta =
CMath::max(delta, CMath::min(alpha*snorm, sigma3*delta));
// Per-iteration log line.
147 SG_INFO(
"iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter)
// Accept the trial point only when the actual reduction exceeds eta0 times the
// predicted one. The update performed on acceptance is not visible here.
149 if (actred > eta0*prered)
// Gradient-norm convergence test; its body is not visible here.
157 if (gnorm < eps*gnorm1)
// Progress report through SG_SABS_PROGRESS on a -log10 scale. The argument
// meaning (current, start, end, precision) is assumed; confirm against the macro.
// NOTE(review): -CMath::log10(1) is always 0, so a literal 0 would be clearer.
159 SG_SABS_PROGRESS(gnorm, -CMath::log10(gnorm), -CMath::log10(1), -CMath::log10(eps*gnorm1), 6)
// True only when both reductions are exactly zero (abs(x) <= 0 holds only for a
// zero x). The action taken is not visible here.
166 if (CMath::abs(actred) <= 0 && CMath::abs(prered) <= 0)
// Both reductions are negligible relative to |f| (relative tolerance 1e-12).
// The action taken is not visible here.
171 if (CMath::abs(actred) <= 1.0e-12*CMath::abs(f) &&
172 CMath::abs(prered) <= 1.0e-12*CMath::abs(f))
// Conjugate-gradient solve of the trust-region subproblem with radius delta,
// using the gradient g. s accumulates the step and r the residual; both are
// presumably initialised on lines not shown. Returns an int32_t, which tron()
// stores as the CG iteration count.
// Lines not visible in this extract: initialisation, the Hd = H*d product,
// the s/r/d updates, the iteration counter, and the frees of Hd and d.
187 int32_t CTron::trcg(
float64_t delta,
double* g,
double* s,
double* r)
191 int n = (int) fun_obj->get_nr_variable();
// Scratch vectors: Hd (presumably the Hessian-vector product H*d) and the
// search direction d.
194 double *Hd = SG_MALLOC(
double, n);
195 double *d = SG_MALLOC(
double, n);
196 double rTr, rnewTrnew, alpha, beta, cgtol;
// CG step length along d: (r.r) / (d.Hd).
215 alpha = rTr/
tron_ddot(n, d, inc, Hd, inc);
// The full CG step would leave the trust region. The code below instead picks
// alpha so that ||s + alpha*d|| == delta.
219 SG_INFO(
"cg reaches trust region boundary\n")
// It solves dtd*a^2 + 2*std*a + (sts - dsq) = 0 for its positive root, where
// std (declared on a line not shown) is presumably s.d.
// NOTE(review): a local named `std` is easy to misread as the namespace;
// consider renaming it (e.g. s_dot_d).
224 double sts =
tron_ddot(n, s, inc, s, inc);
225 double dtd =
tron_ddot(n, d, inc, d, inc);
226 double dsq = delta*delta;
227 double rad = sqrt(
std*
std + dtd*(dsq-sts));
// Two algebraically equal forms of the root: (dsq - sts)/(std + rad) and
// (rad - std)/dtd. The first avoids cancellation when std >= 0, the second when
// std < 0. The if/else selecting between them is not visible here.
229 alpha = (dsq - sts)/(
std + rad);
231 alpha = (rad -
std)/dtd;
// Squared norm of the updated residual, and the CG direction-update
// coefficient beta = (r_new.r_new) / (r.r).
239 rnewTrnew =
tron_ddot(n, r, inc, r, inc);
240 beta = rnewTrnew/rTr;
// Fragment of a maximum-absolute-value (infinity-norm) scan over x[0..n-1].
// The enclosing function and the initialisation of dmax are not visible in this
// extract; since the loop starts at index 1, dmax is presumably seeded from x[0].
255 for (int32_t i=1; i<n; i++)
256 if (CMath::abs(x[i]) >= dmax)
257 dmax = CMath::abs(x[i]);
Class Time implements a stopwatch based on either CPU time or wall-clock time.
Vector::Scalar dot(Vector a, Vector b)
void tron(float64_t *w, float64_t max_train_time)
double tron_ddot(const int N, const double *X, const int incX, const double *Y, const int incY)
float64_t cur_time_diff(bool verbose=false)
Class SGObject is the base class of all shogun objects.
void tron_daxpy(const int N, const double alpha, const double *X, const int incX, double *Y, const int incY)
void tron_dscal(const int N, const double alpha, double *X, const int incX)
All classes and functions are contained in the shogun namespace.
Matrix::Scalar max(Matrix m)
double tron_dnrm2(const int N, const double *X, const int incX)
#define SG_SABS_PROGRESS(...)