00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071 #include <math.h>
00072 #include <stdlib.h>
00073 #include <stdio.h>
00074 #include <string.h>
00075 #include <stdint.h>
00076 #include <limits.h>
00077
00078 #include <shogun/lib/common.h>
00079 #include <shogun/classifier/svm/libqp.h>
00080 namespace shogun
00081 {
00082
00083 libqp_state_T libqp_splx_solver(const float64_t* (*get_col)(uint32_t),
00084 float64_t *diag_H,
00085 float64_t *f,
00086 float64_t *b,
00087 uint32_t *I,
00088 uint8_t *S,
00089 float64_t *x,
00090 uint32_t n,
00091 uint32_t MaxIter,
00092 float64_t TolAbs,
00093 float64_t TolRel,
00094 float64_t QP_TH,
00095 void (*print_state)(libqp_state_T state))
00096 {
00097 float64_t *d;
00098 float64_t *col_u, *col_v;
00099 float64_t *x_neq;
00100 float64_t tmp;
00101 float64_t improv;
00102 float64_t tmp_num;
00103 float64_t tmp_den=0;
00104 float64_t tau=0;
00105 float64_t delta;
00106 uint32_t *inx;
00107 uint32_t *nk;
00108 uint32_t m;
00109 int32_t u=0;
00110 int32_t v=0;
00111 uint32_t k;
00112 uint32_t i, j;
00113 libqp_state_T state;
00114
00115
00116
00117
00118
00119 state.nIter = 0;
00120 state.QP = LIBQP_PLUS_INF;
00121 state.QD = -LIBQP_PLUS_INF;
00122 state.exitflag = 100;
00123
00124 inx=NULL;
00125 nk=NULL;
00126 d=NULL;
00127 x_neq = NULL;
00128
00129
00130 for( i=0, m=0; i < n; i++ )
00131 m = LIBQP_MAX(m,I[i]);
00132
00133
00134 x_neq = (float64_t*) LIBQP_CALLOC(m, sizeof(float64_t));
00135 if( x_neq == NULL )
00136 {
00137 state.exitflag=-1;
00138 goto cleanup;
00139 }
00140
00141
00142 inx = (uint32_t*) LIBQP_CALLOC(m*n, sizeof(uint32_t));
00143 if( inx == NULL )
00144 {
00145 state.exitflag=-1;
00146 goto cleanup;
00147 }
00148
00149
00150 nk = (uint32_t*) LIBQP_CALLOC(m, sizeof(uint32_t));
00151 if( nk == NULL )
00152 {
00153 state.exitflag=-1;
00154 goto cleanup;
00155 }
00156
00157
00158 for( i=0; i < m; i++ )
00159 x_neq[i] = b[i];
00160
00161
00162
00163 for( i=0; i < n; i++ ) {
00164 k = I[i]-1;
00165 inx[LIBQP_INDEX(nk[k],k,n)] = i;
00166 nk[k]++;
00167
00168 if(S[k] != 0)
00169 x_neq[k] -= x[i];
00170 }
00171
00172
00173 d = (float64_t*) LIBQP_CALLOC(n, sizeof(float64_t));
00174 if( d == NULL )
00175 {
00176 state.exitflag=-1;
00177 goto cleanup;
00178 }
00179
00180
00181 for( i=0; i < n; i++ )
00182 {
00183 d[i] += f[i];
00184 if( x[i] > 0 ) {
00185 col_u = (float64_t*)get_col(i);
00186 for( j=0; j < n; j++ ) {
00187 d[j] += col_u[j]*x[i];
00188 }
00189 }
00190 }
00191
00192
00193
00194 for( i=0, state.QP = 0, state.QD=0; i < n; i++)
00195 {
00196 state.QP += x[i]*(f[i]+d[i]);
00197 state.QD += x[i]*(f[i]-d[i]);
00198 }
00199 state.QP = 0.5*state.QP;
00200 state.QD = 0.5*state.QD;
00201
00202 for( i=0; i < m; i++ )
00203 {
00204 for( j=0, tmp = LIBQP_PLUS_INF; j < nk[i]; j++ )
00205 tmp = LIBQP_MIN(tmp, d[inx[LIBQP_INDEX(j,i,n)]]);
00206
00207 if(S[i] == 0)
00208 state.QD += b[i]*tmp;
00209 else
00210 state.QD += b[i]*LIBQP_MIN(tmp,0);
00211 }
00212
00213
00214 if( print_state != NULL)
00215 print_state( state );
00216
00217
00218
00219
00220 while( state.exitflag == 100 )
00221 {
00222 state.nIter ++;
00223
00224
00225 for( k=0; k < m; k++ )
00226 {
00227
00228
00229
00230 for( j=0, tmp = LIBQP_PLUS_INF, delta = 0; j < nk[k]; j++ )
00231 {
00232 i = inx[LIBQP_INDEX(j,k,n)];
00233 delta += x[i]*d[i];
00234 if( tmp > d[i] ) {
00235 tmp = d[i];
00236 u = i;
00237 }
00238 }
00239
00240 if(S[k] != 0 && d[u] > 0)
00241 u = -1;
00242 else
00243 delta -= b[k]*d[u];
00244
00245
00246 if( delta > TolAbs/m && delta > TolRel*LIBQP_ABS(state.QP)/m)
00247 {
00248
00249 if( u != -1 )
00250 {
00251 col_u = (float64_t*)get_col(u);
00252 improv = -LIBQP_PLUS_INF;
00253 for( j=0; j < nk[k]; j++ )
00254 {
00255 i = inx[LIBQP_INDEX(j,k,n)];
00256
00257 if(x[i] > 0 && i != uint32_t(u))
00258 {
00259 tmp_num = x[i]*(d[i] - d[u]);
00260 tmp_den = x[i]*x[i]*(diag_H[u] - 2*col_u[i] + diag_H[i]);
00261 if( tmp_den > 0 )
00262 {
00263 if( tmp_num < tmp_den )
00264 tmp = tmp_num*tmp_num / tmp_den;
00265 else
00266 tmp = tmp_num - 0.5 * tmp_den;
00267
00268 if( tmp > improv )
00269 {
00270 improv = tmp;
00271 tau = LIBQP_MIN(1,tmp_num/tmp_den);
00272 v = i;
00273 }
00274 }
00275 }
00276 }
00277
00278
00279 if(x_neq[k] > 0 && S[k] != 0)
00280 {
00281 tmp_num = -x_neq[k]*d[u];
00282 tmp_den = x_neq[k]*x_neq[k]*diag_H[u];
00283 if( tmp_den > 0 )
00284 {
00285 if( tmp_num < tmp_den )
00286 tmp = tmp_num*tmp_num / tmp_den;
00287 else
00288 tmp = tmp_num - 0.5 * tmp_den;
00289
00290 if( tmp > improv )
00291 {
00292 improv = tmp;
00293 tau = LIBQP_MIN(1,tmp_num/tmp_den);
00294 v = -1;
00295 }
00296 }
00297 }
00298
00299
00300 if(v != -1)
00301 {
00302 tmp = x[v]*tau;
00303 x[u] += tmp;
00304 x[v] -= tmp;
00305
00306
00307 col_v = (float64_t*)get_col(v);
00308 for(i = 0; i < n; i++ )
00309 d[i] += tmp*(col_u[i]-col_v[i]);
00310 }
00311 else
00312 {
00313 tmp = x_neq[k]*tau;
00314 x[u] += tmp;
00315 x_neq[k] -= tmp;
00316
00317
00318 for(i = 0; i < n; i++ )
00319 d[i] += tmp*col_u[i];
00320 }
00321 }
00322 else
00323 {
00324 improv = -LIBQP_PLUS_INF;
00325 for( j=0; j < nk[k]; j++ )
00326 {
00327 i = inx[LIBQP_INDEX(j,k,n)];
00328
00329 if(x[i] > 0)
00330 {
00331 tmp_num = x[i]*d[i];
00332 tmp_den = x[i]*x[i]*diag_H[i];
00333 if( tmp_den > 0 )
00334 {
00335 if( tmp_num < tmp_den )
00336 tmp = tmp_num*tmp_num / tmp_den;
00337 else
00338 tmp = tmp_num - 0.5 * tmp_den;
00339
00340 if( tmp > improv )
00341 {
00342 improv = tmp;
00343 tau = LIBQP_MIN(1,tmp_num/tmp_den);
00344 v = i;
00345 }
00346 }
00347 }
00348 }
00349
00350 tmp = x[v]*tau;
00351 x_neq[k] += tmp;
00352 x[v] -= tmp;
00353
00354
00355 col_v = (float64_t*)get_col(v);
00356 for(i = 0; i < n; i++ )
00357 d[i] -= tmp*col_v[i];
00358 }
00359
00360
00361 state.QP = state.QP - improv;
00362 }
00363 }
00364
00365
00366 for( i=0, state.QP = 0, state.QD=0; i < n; i++)
00367 {
00368 state.QP += x[i]*(f[i]+d[i]);
00369 state.QD += x[i]*(f[i]-d[i]);
00370 }
00371 state.QP = 0.5*state.QP;
00372 state.QD = 0.5*state.QD;
00373
00374 for( k=0; k < m; k++ )
00375 {
00376 for( j=0,tmp = LIBQP_PLUS_INF; j < nk[k]; j++ ) {
00377 i = inx[LIBQP_INDEX(j,k,n)];
00378 tmp = LIBQP_MIN(tmp, d[i]);
00379 }
00380
00381 if(S[k] == 0)
00382 state.QD += b[k]*tmp;
00383 else
00384 state.QD += b[k]*LIBQP_MIN(tmp,0);
00385 }
00386
00387
00388 if( print_state != NULL)
00389 print_state( state );
00390
00391
00392 if(state.QP-state.QD <= LIBQP_ABS(state.QP)*TolRel ) state.exitflag = 1;
00393 else if( state.QP-state.QD <= TolAbs ) state.exitflag = 2;
00394 else if( state.QP <= QP_TH ) state.exitflag = 3;
00395 else if( state.nIter >= MaxIter) state.exitflag = 0;
00396 }
00397
00398
00399
00400
00401 cleanup:
00402 LIBQP_FREE( d );
00403 LIBQP_FREE( inx );
00404 LIBQP_FREE( nk );
00405 LIBQP_FREE( x_neq );
00406
00407 return( state );
00408 }
00409 }
00410