00001 #include <math.h>
00002 #include <stdio.h>
00003 #include <string.h>
00004 #include <stdarg.h>
00005
00006 #include <shogun/lib/config.h>
00007 #include <shogun/lib/Signal.h>
00008 #include <shogun/lib/Time.h>
00009
00010 #ifdef HAVE_LAPACK
00011 #include <shogun/mathematics/Math.h>
00012 #include <shogun/classifier/svm/Tron.h>
00013
00014 using namespace shogun;
00015
00016 CTron::CTron(const function *f, float64_t e, int32_t it)
00017 : CSGObject()
00018 {
00019 this->fun_obj=const_cast<function *>(f);
00020 this->eps=e;
00021 this->max_iter=it;
00022 }
00023
00024 CTron::~CTron()
00025 {
00026 }
00027
00028 void CTron::tron(float64_t *w, float64_t max_train_time)
00029 {
00030
00031 float64_t eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
00032
00033
00034 float64_t sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4.;
00035
00036 int32_t i, cg_iter;
00037 float64_t delta, snorm, one=1.0;
00038 float64_t alpha, f, fnew, prered, actred, gs;
00039
00040
00041 int n = (int) fun_obj->get_nr_variable();
00042 int search = 1, iter = 1, inc = 1;
00043 double *s = SG_MALLOC(double, n);
00044 double *r = SG_MALLOC(double, n);
00045 double *w_new = SG_MALLOC(double, n);
00046 double *g = SG_MALLOC(double, n);
00047
00048 for (i=0; i<n; i++)
00049 w[i] = 0;
00050
00051 f = fun_obj->fun(w);
00052 fun_obj->grad(w, g);
00053 delta = cblas_dnrm2(n, g, inc);
00054 float64_t gnorm1 = delta;
00055 float64_t gnorm = gnorm1;
00056
00057 if (gnorm <= eps*gnorm1)
00058 search = 0;
00059
00060 iter = 1;
00061
00062 CSignal::clear_cancel();
00063 CTime start_time;
00064
00065 while (iter <= max_iter && search && (!CSignal::cancel_computations()))
00066 {
00067 if (max_train_time > 0 && start_time.cur_time_diff() > max_train_time)
00068 break;
00069
00070 cg_iter = trcg(delta, g, s, r);
00071
00072 memcpy(w_new, w, sizeof(float64_t)*n);
00073 cblas_daxpy(n, one, s, inc, w_new, inc);
00074
00075 gs = cblas_ddot(n, g, inc, s, inc);
00076 prered = -0.5*(gs-cblas_ddot(n, s, inc, r, inc));
00077 fnew = fun_obj->fun(w_new);
00078
00079
00080 actred = f - fnew;
00081
00082
00083 snorm = cblas_dnrm2(n, s, inc);
00084 if (iter == 1)
00085 delta = CMath::min(delta, snorm);
00086
00087
00088 if (fnew - f - gs <= 0)
00089 alpha = sigma3;
00090 else
00091 alpha = CMath::max(sigma1, -0.5*(gs/(fnew - f - gs)));
00092
00093
00094 if (actred < eta0*prered)
00095 delta = CMath::min(CMath::max(alpha, sigma1)*snorm, sigma2*delta);
00096 else if (actred < eta1*prered)
00097 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma2*delta));
00098 else if (actred < eta2*prered)
00099 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma3*delta));
00100 else
00101 delta = CMath::max(delta, CMath::min(alpha*snorm, sigma3*delta));
00102
00103 SG_INFO("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
00104
00105 if (actred > eta0*prered)
00106 {
00107 iter++;
00108 memcpy(w, w_new, sizeof(float64_t)*n);
00109 f = fnew;
00110 fun_obj->grad(w, g);
00111
00112 gnorm = cblas_dnrm2(n, g, inc);
00113 if (gnorm < eps*gnorm1)
00114 break;
00115 SG_SABS_PROGRESS(gnorm, -CMath::log10(gnorm), -CMath::log10(1), -CMath::log10(eps*gnorm1), 6);
00116 }
00117 if (f < -1.0e+32)
00118 {
00119 SG_WARNING("f < -1.0e+32\n");
00120 break;
00121 }
00122 if (CMath::abs(actred) <= 0 && CMath::abs(prered) <= 0)
00123 {
00124 SG_WARNING("actred and prered <= 0\n");
00125 break;
00126 }
00127 if (CMath::abs(actred) <= 1.0e-12*CMath::abs(f) &&
00128 CMath::abs(prered) <= 1.0e-12*CMath::abs(f))
00129 {
00130 SG_WARNING("actred and prered too small\n");
00131 break;
00132 }
00133 }
00134
00135 SG_DONE();
00136
00137 SG_FREE(g);
00138 SG_FREE(r);
00139 SG_FREE(w_new);
00140 SG_FREE(s);
00141 }
00142
00143 int32_t CTron::trcg(float64_t delta, double* g, double* s, double* r)
00144 {
00145
00146 int i, cg_iter;
00147 int n = (int) fun_obj->get_nr_variable();
00148 int inc = 1;
00149 double one = 1;
00150 double *Hd = SG_MALLOC(double, n);
00151 double *d = SG_MALLOC(double, n);
00152 double rTr, rnewTrnew, alpha, beta, cgtol;
00153
00154 for (i=0; i<n; i++)
00155 {
00156 s[i] = 0;
00157 r[i] = -g[i];
00158 d[i] = r[i];
00159 }
00160 cgtol = 0.1* cblas_dnrm2(n, g, inc);
00161
00162 cg_iter = 0;
00163 rTr = cblas_ddot(n, r, inc, r, inc);
00164 while (1)
00165 {
00166 if (cblas_dnrm2(n, r, inc) <= cgtol)
00167 break;
00168 cg_iter++;
00169 fun_obj->Hv(d, Hd);
00170
00171 alpha = rTr/cblas_ddot(n, d, inc, Hd, inc);
00172 cblas_daxpy(n, alpha, d, inc, s, inc);
00173 if (cblas_dnrm2(n, s, inc) > delta)
00174 {
00175 SG_INFO("cg reaches trust region boundary\n");
00176 alpha = -alpha;
00177 cblas_daxpy(n, alpha, d, inc, s, inc);
00178
00179 double std = cblas_ddot(n, s, inc, d, inc);
00180 double sts = cblas_ddot(n, s, inc, s, inc);
00181 double dtd = cblas_ddot(n, d, inc, d, inc);
00182 double dsq = delta*delta;
00183 double rad = sqrt(std*std + dtd*(dsq-sts));
00184 if (std >= 0)
00185 alpha = (dsq - sts)/(std + rad);
00186 else
00187 alpha = (rad - std)/dtd;
00188 cblas_daxpy(n, alpha, d, inc, s, inc);
00189 alpha = -alpha;
00190 cblas_daxpy(n, alpha, Hd, inc, r, inc);
00191 break;
00192 }
00193 alpha = -alpha;
00194 cblas_daxpy(n, alpha, Hd, inc, r, inc);
00195 rnewTrnew = cblas_ddot(n, r, inc, r, inc);
00196 beta = rnewTrnew/rTr;
00197 cblas_dscal(n, beta, d, inc);
00198 cblas_daxpy(n, one, r, inc, d, inc);
00199 rTr = rnewTrnew;
00200 }
00201
00202 SG_FREE(d);
00203 SG_FREE(Hd);
00204
00205 return(cg_iter);
00206 }
00207
00208 float64_t CTron::norm_inf(int32_t n, float64_t *x)
00209 {
00210 float64_t dmax = CMath::abs(x[0]);
00211 for (int32_t i=1; i<n; i++)
00212 if (CMath::abs(x[i]) >= dmax)
00213 dmax = CMath::abs(x[i]);
00214 return(dmax);
00215 }
00216 #endif //HAVE_LAPACK