Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00034
00035 #ifndef _LIBLINEAR_H
00036 #define _LIBLINEAR_H
00037
00038 #include <shogun/lib/config.h>
00039
00040 #include <shogun/optimization/liblinear/tron.h>
00041 #include <shogun/features/DotFeatures.h>
00042 #include <vector>
00043
00044 namespace shogun
00045 {
00046
00047 #ifdef __cplusplus
00048 extern "C" {
00049 #endif
00050
/** Training problem handed to the objective functions and solvers declared
 * in this header.
 *
 * Plain data: no constructor or destructor is declared, so the struct itself
 * neither allocates nor frees anything it points to.
 */
00052 struct problem
00053 {
/** presumably the number of training vectors (upstream LIBLINEAR meaning) — TODO confirm */
00055 int32_t l;
/** presumably the number of features, i.e. weight-vector length — TODO confirm */
00057 int32_t n;
/** label array; presumably one entry per training vector (l entries) — TODO confirm */
00059 float64_t* y;
/** feature object holding the training vectors; ownership is not expressed by this declaration */
00061 CDotFeatures* x;
/** whether a bias term is to be used */
00063 bool use_bias;
00064 };
00065
/** Solver configuration; field names follow upstream LIBLINEAR's
 * `struct parameter` (semantics below assumed from that — TODO confirm).
 */
00067 struct parameter
00068 {
/** solver selector; the set of valid values is defined elsewhere */
00070 int32_t solver_type;
00071
00072
/** presumably the stopping tolerance of the solver */
00074 float64_t eps;
/** cost (regularization trade-off) parameter C */
00076 float64_t C;
/** presumably the number of entries in weight_label and weight */
00078 int32_t nr_weight;
/** labels whose cost is reweighted; presumably nr_weight entries */
00080 int32_t *weight_label;
/** cost multipliers paired index-wise with weight_label (presumably) */
00082 float64_t* weight;
00083 };
00084
/** Trained linear model; released with destroy_model(), declared below. */
00086 struct model
00087 {
/** parameters the model was trained with (held by value) */
00089 struct parameter param;
/** number of classes */
00091 int32_t nr_class;
/** number of features */
00093 int32_t nr_feature;
/** weight vector(s); the exact layout for multiclass models is not visible here — TODO confirm */
00095 float64_t *w;
/** class labels; presumably nr_class entries */
00097 int32_t *label;
/** bias value */
00099 float64_t bias;
00100 };
00101
/** Presumably releases the memory held by a model. Defined out of line, so
 * exactly which members are freed (and whether model_ itself is) must be
 * checked in the implementation file.
 *
 * @param model_ model to clean up
 */
00102 void destroy_model(struct model *model_);
/** Presumably releases the arrays held by a parameter struct (weight_label,
 * weight?) — TODO confirm against the out-of-line definition.
 *
 * @param param parameter struct to clean up
 */
00103 void destroy_param(struct parameter *param);
00104 #ifdef __cplusplus
00105 }
00106 #endif
00107
00109 class l2loss_svm_fun : public function
00110 {
00111 public:
00118 l2loss_svm_fun(const problem *prob, float64_t Cp, float64_t Cn);
00119 ~l2loss_svm_fun();
00120
00126 float64_t fun(float64_t *w);
00127
00133 void grad(float64_t *w, float64_t *g);
00134
00140 void Hv(float64_t *s, float64_t *Hs);
00141
00146 int32_t get_nr_variable();
00147
00148 private:
00149 void Xv(float64_t *v, float64_t *Xv);
00150 void subXv(float64_t *v, float64_t *Xv);
00151 void subXTv(float64_t *v, float64_t *XTv);
00152
00153 float64_t *C;
00154 float64_t *z;
00155 float64_t *D;
00156 int32_t *I;
00157 int32_t sizeI;
00158 const problem *prob;
00159 };
00160
00162 class l2r_lr_fun : public function
00163 {
00164 public:
00171 l2r_lr_fun(const problem *prob, float64_t* C);
00172 ~l2r_lr_fun();
00173
00179 float64_t fun(float64_t *w);
00180
00186 void grad(float64_t *w, float64_t *g);
00187
00193 void Hv(float64_t *s, float64_t *Hs);
00194
00195 int32_t get_nr_variable();
00196
00197 private:
00198 void Xv(float64_t *v, float64_t *Xv);
00199 void XTv(float64_t *v, float64_t *XTv);
00200
00201 float64_t *C;
00202 float64_t *z;
00203 float64_t *D;
00204 const problem *m_prob;
00205 };
00206
00207 class l2r_l2_svc_fun : public function
00208 {
00209 public:
00210 l2r_l2_svc_fun(const problem *prob, float64_t* Cs);
00211 ~l2r_l2_svc_fun();
00212
00213 double fun(double *w);
00214 void grad(double *w, double *g);
00215 void Hv(double *s, double *Hs);
00216
00217 int get_nr_variable();
00218
00219 protected:
00220 void Xv(double *v, double *Xv);
00221 void subXv(double *v, double *Xv);
00222 void subXTv(double *v, double *XTv);
00223
00224 double *C;
00225 double *z;
00226 double *D;
00227 int *I;
00228 int sizeI;
00229 const problem *m_prob;
00230 };
00231
00232 class l2r_l2_svr_fun: public l2r_l2_svc_fun
00233 {
00234 public:
00235 l2r_l2_svr_fun(const problem *prob, double *Cs, double p);
00236
00237 double fun(double *w);
00238 void grad(double *w, double *g);
00239
00240 private:
00241 double m_p;
00242 };
00243
00244
00245 struct mcsvm_state
00246 {
00247 double* w;
00248 double* B;
00249 double* G;
00250 double* alpha;
00251 double* alpha_new;
00252 int* index;
00253 double* QD;
00254 int* d_ind;
00255 double* d_val;
00256 int* alpha_index;
00257 int* y_index;
00258 int* active_size_i;
00259 bool allocated,inited;
00260
00261 mcsvm_state()
00262 {
00263 w = NULL;
00264 B = NULL;
00265 G = NULL;
00266 alpha = NULL;
00267 alpha_new = NULL;
00268 index = NULL;
00269 QD = NULL;
00270 d_ind = NULL;
00271 d_val = NULL;
00272 alpha_index = NULL;
00273 y_index = NULL;
00274 active_size_i = NULL;
00275 allocated = false;
00276 inited = false;
00277 }
00278
00279 ~mcsvm_state()
00280 {
00281 SG_FREE(w);
00282 SG_FREE(B);
00283 SG_FREE(G);
00284 SG_FREE(alpha);
00285 SG_FREE(alpha_new);
00286 SG_FREE(index);
00287 SG_FREE(QD);
00288 SG_FREE(d_ind);
00289 SG_FREE(d_val);
00290 SG_FREE(alpha_index);
00291 SG_FREE(y_index);
00292 SG_FREE(active_size_i);
00293 }
00294 };
00295
00296 class Solver_MCSVM_CS
00297 {
00298 public:
00299 Solver_MCSVM_CS(const problem *prob, int nr_class, double *C,
00300 double *w0, double eps, int max_iter,
00301 double train_time, mcsvm_state* given_state);
00302 ~Solver_MCSVM_CS();
00303 void solve();
00304 private:
00305 void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double *alpha_new);
00306 bool be_shrunk(int i, int m, int yi, double alpha_i, double minG);
00307 double *C;
00308 int w_size, l;
00309 int nr_class;
00310 int max_iter;
00311 double eps;
00312 double max_train_time;
00313 double* w0;
00314 const problem *prob;
00315 mcsvm_state* state;
00316 };
00317
00318
00319 }
00320 #endif //_LIBLINEAR_H
00321
00322 #endif // DOXYGEN_SHOULD_SKIP_THIS