00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064 #include <algorithm>
00065 #include <cstdio>
00066 #include <cstdlib>
00067 #include <cmath>
00068
00069 #include <shogun/optimization/lbfgs/lbfgs.h>
00070 #include <shogun/lib/SGVector.h>
00071
00072 namespace shogun
00073 {
00074
/*
 * Small comparison helpers.  These are function-like macros, so an argument
 * may be evaluated more than once: do not pass expressions with side effects.
 */

/* Smaller of a and b (a is chosen on ties). */
#define min2(a, b)      ((a) <= (b) ? (a) : (b))
/* Larger of a and b (a is chosen on ties). */
#define max2(a, b)      ((a) >= (b) ? (a) : (b))
/*
 * Largest of a, b and c.  Expands to a pure expression with no trailing
 * semicolon, so it can be nested inside larger expressions as well as used
 * as the right-hand side of an assignment statement.
 */
#define max3(a, b, c)   max2(max2((a), (b)), (c))
00078
00079 struct tag_callback_data {
00080 int32_t n;
00081 void *instance;
00082 lbfgs_evaluate_t proc_evaluate;
00083 lbfgs_progress_t proc_progress;
00084 };
00085 typedef struct tag_callback_data callback_data_t;
00086
00087 struct tag_iteration_data {
00088 float64_t alpha;
00089 float64_t *s;
00090 float64_t *y;
00091 float64_t ys;
00092 };
00093 typedef struct tag_iteration_data iteration_data_t;
00094
/*
 * Default parameter set.  It is copied out by lbfgs_parameter_init() and is
 * used by lbfgs() whenever the caller passes a NULL parameter pointer.
 *
 * The initializer is positional: the meaning of each value is fixed by the
 * field order of lbfgs_parameter_t, which is declared outside this file.
 * NOTE(review): presumably the order is m, epsilon, past, delta /
 * max_iterations, linesearch, max_linesearch / min_step, max_step, ftol,
 * wolfe, gtol, xtol / orthantwise_c, orthantwise_start, orthantwise_end --
 * confirm against the struct declaration before editing any value.
 */
00095 static const lbfgs_parameter_t _defparam = {
00096 6, 1e-5, 0, 1e-5,
00097 0, LBFGS_LINESEARCH_DEFAULT, 40,
00098 1e-20, 1e20, 1e-4, 0.9, 0.9, 1.0e-16,
00099 0.0, 0, -1,
00100 };
00101
00102
00103
/*
 * Signature shared by the three line-search routines in this file.  lbfgs()
 * stores the selected routine in a variable of this type and calls it once
 * per iteration.  A negative return value is treated by lbfgs() as an error
 * code; otherwise the value is passed on to the progress callback.
 */
00104 typedef int32_t (*line_search_proc)(
00105 int32_t n,
00106 float64_t *x,
00107 float64_t *f,
00108 float64_t *g,
00109 float64_t *s,
00110 float64_t *stp,
00111 const float64_t* xp,
00112 const float64_t* gp,
00113 float64_t *wa,
00114 callback_data_t *cd,
00115 const lbfgs_parameter_t *param
00116 );
00117
/*
 * Forward declaration of the backtracking line search (Armijo / Wolfe /
 * strong Wolfe variants, chosen by param->linesearch).  Defined later in
 * this file.
 */
00118 static int32_t line_search_backtracking(
00119 int32_t n,
00120 float64_t *x,
00121 float64_t *f,
00122 float64_t *g,
00123 float64_t *s,
00124 float64_t *stp,
00125 const float64_t* xp,
00126 const float64_t* gp,
00127 float64_t *wa,
00128 callback_data_t *cd,
00129 const lbfgs_parameter_t *param
00130 );
00131
/*
 * Forward declaration of the backtracking line search that lbfgs() selects
 * when param.orthantwise_c is non-zero.  Defined later in this file.
 */
00132 static int32_t line_search_backtracking_owlqn(
00133 int32_t n,
00134 float64_t *x,
00135 float64_t *f,
00136 float64_t *g,
00137 float64_t *s,
00138 float64_t *stp,
00139 const float64_t* xp,
00140 const float64_t* gp,
00141 float64_t *wp,
00142 callback_data_t *cd,
00143 const lbfgs_parameter_t *param
00144 );
00145
/*
 * Forward declaration of the line search that lbfgs() uses for
 * LBFGS_LINESEARCH_MORETHUENTE (also the initial value of its line-search
 * pointer).  Defined later in this file.
 */
00146 static int32_t line_search_morethuente(
00147 int32_t n,
00148 float64_t *x,
00149 float64_t *f,
00150 float64_t *g,
00151 float64_t *s,
00152 float64_t *stp,
00153 const float64_t* xp,
00154 const float64_t* gp,
00155 float64_t *wa,
00156 callback_data_t *cd,
00157 const lbfgs_parameter_t *param
00158 );
00159
/*
 * Forward declaration of the helper called by line_search_morethuente() to
 * update the interval endpoints (x, y) and to compute the next trial step t.
 * All nine value arguments are passed by pointer and may be modified.
 * Defined later in this file.
 */
00160 static int32_t update_trial_interval(
00161 float64_t *x,
00162 float64_t *fx,
00163 float64_t *dx,
00164 float64_t *y,
00165 float64_t *fy,
00166 float64_t *dy,
00167 float64_t *t,
00168 float64_t *ft,
00169 float64_t *dt,
00170 const float64_t tmin,
00171 const float64_t tmax,
00172 int32_t *brackt
00173 );
00174
/*
 * Forward declaration: sum of |x[i]| for i in [start, n).  Defined later in
 * this file.
 */
00175 static float64_t owlqn_x1norm(
00176 const float64_t* x,
00177 const int32_t start,
00178 const int32_t n
00179 );
00180
/*
 * Forward declaration: fills pg[0..n) from the gradient g, adjusting the
 * entries in [start, end) by the L1 coefficient c.  Defined later in this
 * file.
 */
00181 static void owlqn_pseudo_gradient(
00182 float64_t* pg,
00183 const float64_t* x,
00184 const float64_t* g,
00185 const int32_t n,
00186 const float64_t c,
00187 const int32_t start,
00188 const int32_t end
00189 );
00190
/*
 * Forward declaration: zeroes every d[i], i in [start, end), for which
 * d[i] * sign[i] <= 0.  Defined later in this file.
 */
00191 static void owlqn_project(
00192 float64_t* d,
00193 const float64_t* sign,
00194 const int32_t start,
00195 const int32_t end
00196 );
00197
00198
00199 void lbfgs_parameter_init(lbfgs_parameter_t *param)
00200 {
00201 memcpy(param, &_defparam, sizeof(*param));
00202 }
00203
00204 int32_t lbfgs(
00205 int32_t n,
00206 float64_t *x,
00207 float64_t *ptr_fx,
00208 lbfgs_evaluate_t proc_evaluate,
00209 lbfgs_progress_t proc_progress,
00210 void *instance,
00211 lbfgs_parameter_t *_param
00212 )
00213 {
00214 int32_t ret;
00215 int32_t i, j, k, ls, end, bound;
00216 float64_t step;
00217
00218
00219 lbfgs_parameter_t param = (_param != NULL) ? (*_param) : _defparam;
00220 const int32_t m = param.m;
00221
00222 float64_t *xp = NULL;
00223 float64_t *g = NULL, *gp = NULL, *pg = NULL;
00224 float64_t *d = NULL, *w = NULL, *pf = NULL;
00225 iteration_data_t *lm = NULL, *it = NULL;
00226 float64_t ys, yy;
00227 float64_t xnorm, gnorm, beta;
00228 float64_t fx = 0.;
00229 float64_t rate = 0.;
00230 line_search_proc linesearch = line_search_morethuente;
00231
00232
00233 callback_data_t cd;
00234 cd.n = n;
00235 cd.instance = instance;
00236 cd.proc_evaluate = proc_evaluate;
00237 cd.proc_progress = proc_progress;
00238
00239
00240 if (n <= 0) {
00241 return LBFGSERR_INVALID_N;
00242 }
00243 if (param.epsilon < 0.) {
00244 return LBFGSERR_INVALID_EPSILON;
00245 }
00246 if (param.past < 0) {
00247 return LBFGSERR_INVALID_TESTPERIOD;
00248 }
00249 if (param.delta < 0.) {
00250 return LBFGSERR_INVALID_DELTA;
00251 }
00252 if (param.min_step < 0.) {
00253 return LBFGSERR_INVALID_MINSTEP;
00254 }
00255 if (param.max_step < param.min_step) {
00256 return LBFGSERR_INVALID_MAXSTEP;
00257 }
00258 if (param.ftol < 0.) {
00259 return LBFGSERR_INVALID_FTOL;
00260 }
00261 if (param.linesearch == LBFGS_LINESEARCH_BACKTRACKING_WOLFE ||
00262 param.linesearch == LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
00263 if (param.wolfe <= param.ftol || 1. <= param.wolfe) {
00264 return LBFGSERR_INVALID_WOLFE;
00265 }
00266 }
00267 if (param.gtol < 0.) {
00268 return LBFGSERR_INVALID_GTOL;
00269 }
00270 if (param.xtol < 0.) {
00271 return LBFGSERR_INVALID_XTOL;
00272 }
00273 if (param.max_linesearch <= 0) {
00274 return LBFGSERR_INVALID_MAXLINESEARCH;
00275 }
00276 if (param.orthantwise_c < 0.) {
00277 return LBFGSERR_INVALID_ORTHANTWISE;
00278 }
00279 if (param.orthantwise_start < 0 || n < param.orthantwise_start) {
00280 return LBFGSERR_INVALID_ORTHANTWISE_START;
00281 }
00282 if (param.orthantwise_end < 0) {
00283 param.orthantwise_end = n;
00284 }
00285 if (n < param.orthantwise_end) {
00286 return LBFGSERR_INVALID_ORTHANTWISE_END;
00287 }
00288 if (param.orthantwise_c != 0.) {
00289 switch (param.linesearch) {
00290 case LBFGS_LINESEARCH_BACKTRACKING:
00291 linesearch = line_search_backtracking_owlqn;
00292 break;
00293 default:
00294
00295 return LBFGSERR_INVALID_LINESEARCH;
00296 }
00297 } else {
00298 switch (param.linesearch) {
00299 case LBFGS_LINESEARCH_MORETHUENTE:
00300 linesearch = line_search_morethuente;
00301 break;
00302 case LBFGS_LINESEARCH_BACKTRACKING_ARMIJO:
00303 case LBFGS_LINESEARCH_BACKTRACKING_WOLFE:
00304 case LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE:
00305 linesearch = line_search_backtracking;
00306 break;
00307 default:
00308 return LBFGSERR_INVALID_LINESEARCH;
00309 }
00310 }
00311
00312
00313 xp = SG_CALLOC(float64_t, n);
00314 g = SG_CALLOC(float64_t, n);
00315 gp = SG_CALLOC(float64_t, n);
00316 d = SG_CALLOC(float64_t, n);
00317 w = SG_CALLOC(float64_t, n);
00318 if (xp == NULL || g == NULL || gp == NULL || d == NULL || w == NULL) {
00319 ret = LBFGSERR_OUTOFMEMORY;
00320 goto lbfgs_exit;
00321 }
00322
00323 if (param.orthantwise_c != 0.) {
00324
00325 pg = SG_CALLOC(float64_t, n);
00326 if (pg == NULL) {
00327 ret = LBFGSERR_OUTOFMEMORY;
00328 goto lbfgs_exit;
00329 }
00330 }
00331
00332
00333 lm = SG_CALLOC(iteration_data_t, m);
00334 if (lm == NULL) {
00335 ret = LBFGSERR_OUTOFMEMORY;
00336 goto lbfgs_exit;
00337 }
00338
00339
00340 for (i = 0;i < m;++i) {
00341 it = &lm[i];
00342 it->alpha = 0;
00343 it->ys = 0;
00344 it->s = SG_CALLOC(float64_t, n);
00345 it->y = SG_CALLOC(float64_t, n);
00346 if (it->s == NULL || it->y == NULL) {
00347 ret = LBFGSERR_OUTOFMEMORY;
00348 goto lbfgs_exit;
00349 }
00350 }
00351
00352
00353 if (0 < param.past) {
00354 pf = SG_CALLOC(float64_t, param.past);
00355 }
00356
00357
00358 fx = cd.proc_evaluate(cd.instance, x, g, cd.n, 0);
00359 if (0. != param.orthantwise_c) {
00360
00361 xnorm = owlqn_x1norm(x, param.orthantwise_start, param.orthantwise_end);
00362 fx += xnorm * param.orthantwise_c;
00363 owlqn_pseudo_gradient(
00364 pg, x, g, n,
00365 param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
00366 );
00367 }
00368
00369
00370 if (pf != NULL) {
00371 pf[0] = fx;
00372 }
00373
00374
00375
00376
00377
00378 if (param.orthantwise_c == 0.) {
00379 std::copy(g,g+n,d);
00380 SGVector<float64_t>::scale_vector(-1, d, n);
00381 } else {
00382 std::copy(pg,pg+n,d);
00383 SGVector<float64_t>::scale_vector(-1, d, n);
00384 }
00385
00386
00387
00388
00389 xnorm = SGVector<float64_t>::twonorm(x, n);
00390 if (param.orthantwise_c == 0.) {
00391 gnorm = SGVector<float64_t>::twonorm(g, n);
00392 } else {
00393 gnorm = SGVector<float64_t>::twonorm(pg, n);
00394 }
00395 if (xnorm < 1.0) xnorm = 1.0;
00396 if (gnorm / xnorm <= param.epsilon) {
00397 ret = LBFGS_ALREADY_MINIMIZED;
00398 goto lbfgs_exit;
00399 }
00400
00401
00402
00403
00404 step = 1.0 / SGVector<float64_t>::twonorm(d, n);
00405
00406 k = 1;
00407 end = 0;
00408 for (;;) {
00409
00410 std::copy(x,x+n,xp);
00411 std::copy(g,g+n,gp);
00412
00413
00414 if (param.orthantwise_c == 0.) {
00415 ls = linesearch(n, x, &fx, g, d, &step, xp, gp, w, &cd, ¶m);
00416 } else {
00417 ls = linesearch(n, x, &fx, g, d, &step, xp, pg, w, &cd, ¶m);
00418 owlqn_pseudo_gradient(
00419 pg, x, g, n,
00420 param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
00421 );
00422 }
00423 if (ls < 0) {
00424
00425 std::copy(xp,xp+n,x);
00426 std::copy(gp,gp+n,g);
00427 ret = ls;
00428 goto lbfgs_exit;
00429 }
00430
00431
00432 xnorm = SGVector<float64_t>::twonorm(x, n);
00433 if (param.orthantwise_c == 0.) {
00434 gnorm = SGVector<float64_t>::twonorm(g, n);
00435 } else {
00436 gnorm = SGVector<float64_t>::twonorm(pg, n);
00437 }
00438
00439
00440 if (cd.proc_progress) {
00441 if ((ret = cd.proc_progress(cd.instance, x, g, fx, xnorm, gnorm, step, cd.n, k, ls))) {
00442 goto lbfgs_exit;
00443 }
00444 }
00445
00446
00447
00448
00449
00450
00451 if (xnorm < 1.0) xnorm = 1.0;
00452 if (gnorm / xnorm <= param.epsilon) {
00453
00454 ret = LBFGS_SUCCESS;
00455 break;
00456 }
00457
00458
00459
00460
00461
00462
00463 if (pf != NULL) {
00464
00465 if (param.past <= k) {
00466
00467 rate = (pf[k % param.past] - fx) / fx;
00468
00469
00470 if (rate < param.delta) {
00471 ret = LBFGS_STOP;
00472 break;
00473 }
00474 }
00475
00476
00477 pf[k % param.past] = fx;
00478 }
00479
00480 if (param.max_iterations != 0 && param.max_iterations < k+1) {
00481
00482 ret = LBFGSERR_MAXIMUMITERATION;
00483 break;
00484 }
00485
00486
00487
00488
00489
00490
00491 it = &lm[end];
00492 SGVector<float64_t>::add(it->s, 1, x, -1, xp, n);
00493 SGVector<float64_t>::add(it->y, 1, g, -1, gp, n);
00494
00495
00496
00497
00498
00499
00500
00501 ys = SGVector<float64_t>::dot(it->y, it->s, n);
00502 yy = SGVector<float64_t>::dot(it->y, it->y, n);
00503 it->ys = ys;
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513 bound = (m <= k) ? m : k;
00514 ++k;
00515 end = (end + 1) % m;
00516
00517
00518 if (param.orthantwise_c == 0.) {
00519
00520 std::copy(g, g+n, d);
00521 SGVector<float64_t>::scale_vector(-1, d, n);
00522 } else {
00523 std::copy(pg, pg+n, d);
00524 SGVector<float64_t>::scale_vector(-1, d, n);
00525 }
00526
00527 j = end;
00528 for (i = 0;i < bound;++i) {
00529 j = (j + m - 1) % m;
00530 it = &lm[j];
00531
00532 it->alpha = SGVector<float64_t>::dot(it->s, d, n);
00533 it->alpha /= it->ys;
00534
00535 SGVector<float64_t>::add(d, 1, d, -it->alpha, it->y, n);
00536 }
00537
00538 SGVector<float64_t>::scale_vector(ys / yy, d, n);
00539
00540 for (i = 0;i < bound;++i) {
00541 it = &lm[j];
00542
00543 beta = SGVector<float64_t>::dot(it->y, d, n);
00544 beta /= it->ys;
00545
00546 SGVector<float64_t>::add(d, 1, d, it->alpha-beta, it->s, n);
00547 j = (j + 1) % m;
00548 }
00549
00550
00551
00552
00553 if (param.orthantwise_c != 0.) {
00554 for (i = param.orthantwise_start;i < param.orthantwise_end;++i) {
00555 if (d[i] * pg[i] >= 0) {
00556 d[i] = 0;
00557 }
00558 }
00559 }
00560
00561
00562
00563
00564 step = 1.0;
00565 }
00566
00567 lbfgs_exit:
00568
00569 if (ptr_fx != NULL) {
00570 *ptr_fx = fx;
00571 }
00572
00573 SG_FREE(pf);
00574
00575
00576 if (lm != NULL) {
00577 for (i = 0;i < m;++i) {
00578 SG_FREE(lm[i].s);
00579 SG_FREE(lm[i].y);
00580 }
00581 SG_FREE(lm);
00582 }
00583 SG_FREE(pg);
00584 SG_FREE(w);
00585 SG_FREE(d);
00586 SG_FREE(gp);
00587 SG_FREE(g);
00588 SG_FREE(xp);
00589
00590 return ret;
00591 }
00592
00593
00594
00595 static int32_t line_search_backtracking(
00596 int32_t n,
00597 float64_t *x,
00598 float64_t *f,
00599 float64_t *g,
00600 float64_t *s,
00601 float64_t *stp,
00602 const float64_t* xp,
00603 const float64_t* gp,
00604 float64_t *wp,
00605 callback_data_t *cd,
00606 const lbfgs_parameter_t *param
00607 )
00608 {
00609 int32_t count = 0;
00610 float64_t width, dg;
00611 float64_t finit, dginit = 0., dgtest;
00612 const float64_t dec = 0.5, inc = 2.1;
00613
00614
00615 if (*stp <= 0.) {
00616 return LBFGSERR_INVALIDPARAMETERS;
00617 }
00618
00619
00620 dginit = SGVector<float64_t>::dot(g, s, n);
00621
00622
00623 if (0 < dginit) {
00624 return LBFGSERR_INCREASEGRADIENT;
00625 }
00626
00627
00628 finit = *f;
00629 dgtest = param->ftol * dginit;
00630
00631 for (;;) {
00632 std::copy(xp,xp+n,x);
00633 SGVector<float64_t>::add(x, 1, x, *stp, s, n);
00634
00635
00636 *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
00637
00638 ++count;
00639
00640 if (*f > finit + *stp * dgtest) {
00641 width = dec;
00642 } else {
00643
00644 if (param->linesearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO) {
00645
00646 return count;
00647 }
00648
00649
00650 dg = SGVector<float64_t>::dot(g, s, n);
00651 if (dg < param->wolfe * dginit) {
00652 width = inc;
00653 } else {
00654 if(param->linesearch == LBFGS_LINESEARCH_BACKTRACKING_WOLFE) {
00655
00656 return count;
00657 }
00658
00659
00660 if(dg > -param->wolfe * dginit) {
00661 width = dec;
00662 } else {
00663
00664 return count;
00665 }
00666 }
00667 }
00668
00669 if (*stp < param->min_step) {
00670
00671 return LBFGSERR_MINIMUMSTEP;
00672 }
00673 if (*stp > param->max_step) {
00674
00675 return LBFGSERR_MAXIMUMSTEP;
00676 }
00677 if (param->max_linesearch <= count) {
00678
00679 return LBFGSERR_MAXIMUMLINESEARCH;
00680 }
00681
00682 (*stp) *= width;
00683 }
00684 }
00685
00686
00687
00688 static int32_t line_search_backtracking_owlqn(
00689 int32_t n,
00690 float64_t *x,
00691 float64_t *f,
00692 float64_t *g,
00693 float64_t *s,
00694 float64_t *stp,
00695 const float64_t* xp,
00696 const float64_t* gp,
00697 float64_t *wp,
00698 callback_data_t *cd,
00699 const lbfgs_parameter_t *param
00700 )
00701 {
00702 int32_t i, count = 0;
00703 float64_t width = 0.5, norm = 0.;
00704 float64_t finit = *f, dgtest;
00705
00706
00707 if (*stp <= 0.) {
00708 return LBFGSERR_INVALIDPARAMETERS;
00709 }
00710
00711
00712 for (i = 0;i < n;++i) {
00713 wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i];
00714 }
00715
00716 for (;;) {
00717
00718 std::copy(xp,xp+n,x);
00719 SGVector<float64_t>::add(x, 1, x, *stp, s, n);
00720
00721
00722 owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
00723
00724
00725 *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
00726
00727
00728 norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
00729 *f += norm * param->orthantwise_c;
00730
00731 ++count;
00732
00733 dgtest = 0.;
00734 for (i = 0;i < n;++i) {
00735 dgtest += (x[i] - xp[i]) * gp[i];
00736 }
00737
00738 if (*f <= finit + param->ftol * dgtest) {
00739
00740 return count;
00741 }
00742
00743 if (*stp < param->min_step) {
00744
00745 return LBFGSERR_MINIMUMSTEP;
00746 }
00747 if (*stp > param->max_step) {
00748
00749 return LBFGSERR_MAXIMUMSTEP;
00750 }
00751 if (param->max_linesearch <= count) {
00752
00753 return LBFGSERR_MAXIMUMLINESEARCH;
00754 }
00755
00756 (*stp) *= width;
00757 }
00758 }
00759
00760
00761
00762 static int32_t line_search_morethuente(
00763 int32_t n,
00764 float64_t *x,
00765 float64_t *f,
00766 float64_t *g,
00767 float64_t *s,
00768 float64_t *stp,
00769 const float64_t* xp,
00770 const float64_t* gp,
00771 float64_t *wa,
00772 callback_data_t *cd,
00773 const lbfgs_parameter_t *param
00774 )
00775 {
00776 int32_t count = 0;
00777 int32_t brackt, stage1, uinfo = 0;
00778 float64_t dg;
00779 float64_t stx, fx, dgx;
00780 float64_t sty, fy, dgy;
00781 float64_t fxm, dgxm, fym, dgym, fm, dgm;
00782 float64_t finit, ftest1, dginit, dgtest;
00783 float64_t width, prev_width;
00784 float64_t stmin, stmax;
00785
00786
00787 if (*stp <= 0.) {
00788 return LBFGSERR_INVALIDPARAMETERS;
00789 }
00790
00791
00792 dginit = SGVector<float64_t>::dot(g, s, n);
00793
00794
00795 if (0 < dginit) {
00796 return LBFGSERR_INCREASEGRADIENT;
00797 }
00798
00799
00800 brackt = 0;
00801 stage1 = 1;
00802 finit = *f;
00803 dgtest = param->ftol * dginit;
00804 width = param->max_step - param->min_step;
00805 prev_width = 2.0 * width;
00806
00807
00808
00809
00810
00811
00812
00813
00814
00815
00816 stx = sty = 0.;
00817 fx = fy = finit;
00818 dgx = dgy = dginit;
00819
00820 for (;;) {
00821
00822
00823
00824
00825 if (brackt) {
00826 stmin = min2(stx, sty);
00827 stmax = max2(stx, sty);
00828 } else {
00829 stmin = stx;
00830 stmax = *stp + 4.0 * (*stp - stx);
00831 }
00832
00833
00834 if (*stp < param->min_step) *stp = param->min_step;
00835 if (param->max_step < *stp) *stp = param->max_step;
00836
00837
00838
00839
00840
00841 if ((brackt && ((*stp <= stmin || stmax <= *stp) || param->max_linesearch <= count + 1 || uinfo != 0)) || (brackt && (stmax - stmin <= param->xtol * stmax))) {
00842 *stp = stx;
00843 }
00844
00845
00846
00847
00848
00849 std::copy(xp,xp+n,x);
00850 SGVector<float64_t>::add(x, 1, x, *stp, s, n);
00851
00852
00853 *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
00854 dg = SGVector<float64_t>::dot(g, s, n);
00855
00856 ftest1 = finit + *stp * dgtest;
00857 ++count;
00858
00859
00860 if (brackt && ((*stp <= stmin || stmax <= *stp) || uinfo != 0)) {
00861
00862 return LBFGSERR_ROUNDING_ERROR;
00863 }
00864 if (*stp == param->max_step && *f <= ftest1 && dg <= dgtest) {
00865
00866 return LBFGSERR_MAXIMUMSTEP;
00867 }
00868 if (*stp == param->min_step && (ftest1 < *f || dgtest <= dg)) {
00869
00870 return LBFGSERR_MINIMUMSTEP;
00871 }
00872 if (brackt && (stmax - stmin) <= param->xtol * stmax) {
00873
00874 return LBFGSERR_WIDTHTOOSMALL;
00875 }
00876 if (param->max_linesearch <= count) {
00877
00878 return LBFGSERR_MAXIMUMLINESEARCH;
00879 }
00880 if (*f <= ftest1 && fabs(dg) <= param->gtol * (-dginit)) {
00881
00882 return count;
00883 }
00884
00885
00886
00887
00888
00889 if (stage1 && *f <= ftest1 && min2(param->ftol, param->gtol) * dginit <= dg) {
00890 stage1 = 0;
00891 }
00892
00893
00894
00895
00896
00897
00898
00899
00900 if (stage1 && ftest1 < *f && *f <= fx) {
00901
00902 fm = *f - *stp * dgtest;
00903 fxm = fx - stx * dgtest;
00904 fym = fy - sty * dgtest;
00905 dgm = dg - dgtest;
00906 dgxm = dgx - dgtest;
00907 dgym = dgy - dgtest;
00908
00909
00910
00911
00912
00913 uinfo = update_trial_interval(
00914 &stx, &fxm, &dgxm,
00915 &sty, &fym, &dgym,
00916 stp, &fm, &dgm,
00917 stmin, stmax, &brackt
00918 );
00919
00920
00921 fx = fxm + stx * dgtest;
00922 fy = fym + sty * dgtest;
00923 dgx = dgxm + dgtest;
00924 dgy = dgym + dgtest;
00925 } else {
00926
00927
00928
00929
00930 uinfo = update_trial_interval(
00931 &stx, &fx, &dgx,
00932 &sty, &fy, &dgy,
00933 stp, f, &dg,
00934 stmin, stmax, &brackt
00935 );
00936 }
00937
00938
00939
00940
00941 if (brackt) {
00942 if (0.66 * prev_width <= fabs(sty - stx)) {
00943 *stp = stx + 0.5 * (sty - stx);
00944 }
00945 prev_width = width;
00946 width = fabs(sty - stx);
00947 }
00948 }
00949
00950 return LBFGSERR_LOGICERROR;
00951 }
00952
00953
00954
/*
 * Declares the scratch variables that the *_MINIMIZER macros below read and
 * write.  Must appear once in any function that expands one of them.
 */
00958 #define USES_MINIMIZER \
00959 float64_t a, d, gamma, theta, p, q, r, s;
00960
/*
 * Stores in cm a minimizer estimate computed from the function values fu, fv
 * and the derivatives du, dv at the points u and v (a cubic fit, per the
 * macro name).
 *
 * Expands to several statements and uses the scratch variables declared by
 * USES_MINIMIZER, so it must be expanded inside a braced block.  The
 * argument of sqrt() is not clamped here (CUBIC_MINIMIZER2 clamps it at 0).
 */
00971 #define CUBIC_MINIMIZER(cm, u, fu, du, v, fv, dv) \
00972 d = (v) - (u); \
00973 theta = ((fu) - (fv)) * 3 / d + (du) + (dv); \
00974 p = fabs(theta); \
00975 q = fabs(du); \
00976 r = fabs(dv); \
00977 s = max3(p, q, r); \
00978 \
00979 a = theta / s; \
00980 gamma = s * sqrt(a * a - ((du) / s) * ((dv) / s)); \
00981 if ((v) < (u)) gamma = -gamma; \
00982 p = gamma - (du) + theta; \
00983 q = gamma - (du) + gamma + (dv); \
00984 r = p / q; \
00985 (cm) = (u) + r * d;
00986
/*
 * Variant of CUBIC_MINIMIZER with fallback values: the argument of sqrt()
 * is clamped at 0, and cm is set to (v) - r * d only when r < 0 and
 * gamma != 0.  Otherwise cm becomes xmax when a < 0 and xmin when not.
 *
 * Expands to several statements and uses the scratch variables declared by
 * USES_MINIMIZER, so it must be expanded inside a braced block.
 */
00999 #define CUBIC_MINIMIZER2(cm, u, fu, du, v, fv, dv, xmin, xmax) \
01000 d = (v) - (u); \
01001 theta = ((fu) - (fv)) * 3 / d + (du) + (dv); \
01002 p = fabs(theta); \
01003 q = fabs(du); \
01004 r = fabs(dv); \
01005 s = max3(p, q, r); \
01006 \
01007 a = theta / s; \
01008 gamma = s * sqrt(max2(0, a * a - ((du) / s) * ((dv) / s))); \
01009 if ((u) < (v)) gamma = -gamma; \
01010 p = gamma - (dv) + theta; \
01011 q = gamma - (dv) + gamma + (du); \
01012 r = p / q; \
01013 if (r < 0. && gamma != 0.) { \
01014 (cm) = (v) - r * d; \
01015 } else if (a < 0) { \
01016 (cm) = (xmax); \
01017 } else { \
01018 (cm) = (xmin); \
01019 }
01020
/*
 * Stores in qm a minimizer estimate computed from the function values fu, fv
 * at u and v and the derivative du at u (a quadratic fit, per the macro
 * name).  Overwrites the scratch variable a declared by USES_MINIMIZER and
 * expands to two statements.
 */
01030 #define QUARD_MINIMIZER(qm, u, fu, du, v, fv) \
01031 a = (v) - (u); \
01032 (qm) = (u) + (du) / (((fu) - (fv)) / a + (du)) / 2 * a;
01033
/*
 * Stores in qm a minimizer estimate computed from the derivatives du and dv
 * at u and v only; no function values are used.  Overwrites the scratch
 * variable a declared by USES_MINIMIZER and expands to two statements.
 */
01042 #define QUARD_MINIMIZER2(qm, u, du, v, dv) \
01043 a = (u) - (v); \
01044 (qm) = (v) + (dv) / ((dv) - (du)) * a;
01045
/*
 * True when *x and *y have opposite signs.  Both arguments are pointers.
 * The test divides by fabs(*y), so a zero *y produces NaN and the result is
 * false; a zero *x also gives false.
 */
01046 #define fsigndiff(x, y) (*(x) * (*(y) / fabs(*(y))) < 0.)
01047
01077 static int32_t update_trial_interval(
01078 float64_t *x,
01079 float64_t *fx,
01080 float64_t *dx,
01081 float64_t *y,
01082 float64_t *fy,
01083 float64_t *dy,
01084 float64_t *t,
01085 float64_t *ft,
01086 float64_t *dt,
01087 const float64_t tmin,
01088 const float64_t tmax,
01089 int32_t *brackt
01090 )
01091 {
01092 int32_t bound;
01093 int32_t dsign = fsigndiff(dt, dx);
01094 float64_t mc;
01095 float64_t mq;
01096 float64_t newt;
01097 USES_MINIMIZER;
01098
01099
01100 if (*brackt) {
01101 if (*t <= min2(*x, *y) || max2(*x, *y) <= *t) {
01102
01103 return LBFGSERR_OUTOFINTERVAL;
01104 }
01105 if (0. <= *dx * (*t - *x)) {
01106
01107 return LBFGSERR_INCREASEGRADIENT;
01108 }
01109 if (tmax < tmin) {
01110
01111 return LBFGSERR_INCORRECT_TMINMAX;
01112 }
01113 }
01114
01115
01116
01117
01118 if (*fx < *ft) {
01119
01120
01121
01122
01123
01124
01125 *brackt = 1;
01126 bound = 1;
01127 CUBIC_MINIMIZER(mc, *x, *fx, *dx, *t, *ft, *dt);
01128 QUARD_MINIMIZER(mq, *x, *fx, *dx, *t, *ft);
01129 if (fabs(mc - *x) < fabs(mq - *x)) {
01130 newt = mc;
01131 } else {
01132 newt = mc + 0.5 * (mq - mc);
01133 }
01134 } else if (dsign) {
01135
01136
01137
01138
01139
01140
01141 *brackt = 1;
01142 bound = 0;
01143 CUBIC_MINIMIZER(mc, *x, *fx, *dx, *t, *ft, *dt);
01144 QUARD_MINIMIZER2(mq, *x, *dx, *t, *dt);
01145 if (fabs(mc - *t) > fabs(mq - *t)) {
01146 newt = mc;
01147 } else {
01148 newt = mq;
01149 }
01150 } else if (fabs(*dt) < fabs(*dx)) {
01151
01152
01153
01154
01155
01156
01157
01158
01159
01160
01161
01162 bound = 1;
01163 CUBIC_MINIMIZER2(mc, *x, *fx, *dx, *t, *ft, *dt, tmin, tmax);
01164 QUARD_MINIMIZER2(mq, *x, *dx, *t, *dt);
01165 if (*brackt) {
01166 if (fabs(*t - mc) < fabs(*t - mq)) {
01167 newt = mc;
01168 } else {
01169 newt = mq;
01170 }
01171 } else {
01172 if (fabs(*t - mc) > fabs(*t - mq)) {
01173 newt = mc;
01174 } else {
01175 newt = mq;
01176 }
01177 }
01178 } else {
01179
01180
01181
01182
01183
01184
01185 bound = 0;
01186 if (*brackt) {
01187 CUBIC_MINIMIZER(newt, *t, *ft, *dt, *y, *fy, *dy);
01188 } else if (*x < *t) {
01189 newt = tmax;
01190 } else {
01191 newt = tmin;
01192 }
01193 }
01194
01195
01196
01197
01198
01199
01200
01201
01202
01203
01204
01205
01206 if (*fx < *ft) {
01207
01208 *y = *t;
01209 *fy = *ft;
01210 *dy = *dt;
01211 } else {
01212
01213 if (dsign) {
01214 *y = *x;
01215 *fy = *fx;
01216 *dy = *dx;
01217 }
01218
01219 *x = *t;
01220 *fx = *ft;
01221 *dx = *dt;
01222 }
01223
01224
01225 if (tmax < newt) newt = tmax;
01226 if (newt < tmin) newt = tmin;
01227
01228
01229
01230
01231
01232 if (*brackt && bound) {
01233 mq = *x + 0.66 * (*y - *x);
01234 if (*x < *y) {
01235 if (mq < newt) newt = mq;
01236 } else {
01237 if (newt < mq) newt = mq;
01238 }
01239 }
01240
01241
01242 *t = newt;
01243 return 0;
01244 }
01245
01246
01247
01248
01249
01250 static float64_t owlqn_x1norm(
01251 const float64_t* x,
01252 const int32_t start,
01253 const int32_t n
01254 )
01255 {
01256 int32_t i;
01257 float64_t norm = 0.;
01258
01259 for (i = start;i < n;++i) {
01260 norm += fabs(x[i]);
01261 }
01262
01263 return norm;
01264 }
01265
01266 static void owlqn_pseudo_gradient(
01267 float64_t* pg,
01268 const float64_t* x,
01269 const float64_t* g,
01270 const int32_t n,
01271 const float64_t c,
01272 const int32_t start,
01273 const int32_t end
01274 )
01275 {
01276 int32_t i;
01277
01278
01279 for (i = 0;i < start;++i) {
01280 pg[i] = g[i];
01281 }
01282
01283
01284 for (i = start;i < end;++i) {
01285 if (x[i] < 0.) {
01286
01287 pg[i] = g[i] - c;
01288 } else if (0. < x[i]) {
01289
01290 pg[i] = g[i] + c;
01291 } else {
01292 if (g[i] < -c) {
01293
01294 pg[i] = g[i] + c;
01295 } else if (c < g[i]) {
01296
01297 pg[i] = g[i] - c;
01298 } else {
01299 pg[i] = 0.;
01300 }
01301 }
01302 }
01303
01304 for (i = end;i < n;++i) {
01305 pg[i] = g[i];
01306 }
01307 }
01308
01309 static void owlqn_project(
01310 float64_t* d,
01311 const float64_t* sign,
01312 const int32_t start,
01313 const int32_t end
01314 )
01315 {
01316 int32_t i;
01317
01318 for (i = start;i < end;++i) {
01319 if (d[i] * sign[i] <= 0) {
01320 d[i] = 0;
01321 }
01322 }
01323 }
01324
01325 }