00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef _SVMLight_H___
00023 #define _SVMLight_H___
00024
00025 #include <shogun/lib/config.h>
00026
00027 #ifdef USE_SVMLIGHT
00028 #include <shogun/classifier/svm/SVM.h>
00029 #include <shogun/kernel/Kernel.h>
00030 #include <shogun/mathematics/Math.h>
00031 #include <shogun/lib/common.h>
00032
00033 #include <stdio.h>
00034 #include <ctype.h>
00035 #include <string.h>
00036 #include <stdlib.h>
00037 #include <time.h>
00038
00039 namespace shogun
00040 {
00041
00042
00043
/* Numeric constants used by this header's solver declarations.
 * NOTE: these are preprocessor macros, so they are NOT confined to
 * namespace shogun even though they are defined inside its braces. */
/** presumably the default numerical precision -- TODO confirm against the implementation */
00044 # define DEF_PRECISION 1E-14
/** presumably an upper bound used by the shrinking heuristic -- TODO confirm */
00045 # define MAXSHRINK 50000
00046
00047 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00048
/* Plain aggregate holding the state of a model (all members public, no
 * constructor, so nothing is initialized automatically).
 * Hidden from Doxygen by the surrounding DOXYGEN_SHOULD_SKIP_THIS guard.
 * Member meanings below are inferred from the identifiers only and are
 * marked as such -- TODO confirm against the implementation file.
 * Ownership of the pointer members is not visible from this header. */
00049 struct MODEL {
    /** presumably the number of support vectors */
00051 int32_t sv_num;
    /** presumably the number of coefficients sitting at their upper bound */
00053 int32_t at_upper_bound;
    /** presumably the bias/threshold term b */
00055 float64_t b;
    /** integer array; presumably indices of the support vectors */
00057 int32_t* supvec;
    /** floating point array; presumably the alpha coefficients */
00059 float64_t *alpha;
    /** integer array; presumably a lookup index into supvec/alpha -- verify */
00061 int32_t *index;
    /** presumably the total number of training examples ("documents") */
00063 int32_t totdoc;
    /** kernel object associated with the model */
00065 CKernel* kernel;
00066
00067
    /* loo_*: presumably leave-one-out estimates -- TODO confirm */
00069 float64_t loo_error;
00071 float64_t loo_recall;
00073 float64_t loo_precision;
00074
    /* xa_*: presumably xi/alpha estimates (SVM-light terminology) -- TODO confirm */
00076 float64_t xa_error;
00078 float64_t xa_recall;
00080 float64_t xa_precision;
00081 };
00082
/* Description of a quadratic (sub)problem, passed around as `QP *` by the
 * optimize_* / compute_matrices_* members of CSVMLight.
 * Member meanings are inferred from the SVM-light style names only --
 * TODO confirm. Array lengths and ownership are not visible here. */
00084 typedef struct quadratic_program {
    /** presumably the number of variables */
00086 int32_t opt_n;
    /** presumably the number of linear equality constraints */
00088 int32_t opt_m;
    /** presumably the coefficients of the equality constraints */
00090 float64_t *opt_ce;
    /** presumably the constant term of the equality constraints */
00092 float64_t *opt_ce0;
    /** presumably the quadratic part (Hessian) of the objective */
00094 float64_t *opt_g;
    /** presumably the linear part of the objective */
00096 float64_t *opt_g0;
    /** presumably the initial values of the variables */
00098 float64_t *opt_xinit;
    /** presumably the lower bounds of the variables */
00100 float64_t *opt_low;
    /** presumably the upper bounds of the variables */
00102 float64_t *opt_up;
00103 } QP;
00104
/** Alias for int32_t; the name suggests "feature number" -- TODO confirm. */
00106 typedef int32_t FNUM;
00107
/** Alias for float64_t; the name suggests "feature value" -- TODO confirm. */
00109 typedef float64_t FVAL;
00110
/* Bundle of learning parameters; CSVMLight keeps a pointer to one instance
 * (member `learn_parm`). Plain aggregate: nothing is initialized
 * automatically. Meanings below are inferred from the identifiers only --
 * TODO confirm against the implementation. Several int32_t members look
 * like boolean switches, but that is not established by this header. */
00112 struct LEARN_PARM {
    /** presumably selects the kind of learning problem */
00114 int32_t type;
    /** presumably the regularization constant C */
00116 float64_t svm_c;
    /** pointer to floating point data; presumably per-example epsilon values -- verify */
00118 float64_t* eps;
    /** presumably the cost ratio between the two classes */
00120 float64_t svm_costratio;
    /** presumably the expected fraction of positives for transduction */
00122 float64_t transduction_posratio;
00123
    /** presumably whether a bias term is used */
00125 int32_t biased_hyperplane;
    /** presumably whether slack variables are shared -- verify */
00130 int32_t sharedslack;
    /** presumably the maximum size of a QP subproblem */
00132 int32_t svm_maxqpsize;
    /** presumably the number of new variables entering each QP subproblem */
00134 int32_t svm_newvarsinqp;
    /** presumably the kernel cache size (units not visible here) */
00136 int32_t kernel_cache_size;
    /** presumably the tolerated error for the termination criterion */
00138 float64_t epsilon_crit;
    /** presumably the tolerance used while shrinking */
00140 float64_t epsilon_shrink;
    /** presumably the number of iterations after which a variable may be shrunk */
00142 int32_t svm_iter_to_shrink;
    /** presumably the maximum number of iterations */
00146 int32_t maxiter;
    /** presumably whether inconsistent examples are removed and training repeated */
00148 int32_t remove_inconsistent;
    /** presumably whether the final optimality check is skipped */
00152 int32_t skip_final_opt_check;
    /** presumably whether leave-one-out estimates are computed */
00154 int32_t compute_loo;
    /** meaning not derivable from this header -- TODO document */
00158 float64_t rho;
    /** meaning not derivable from this header -- TODO document */
00162 int32_t xa_depth;
    /** fixed 200-byte character buffer; presumably a file name for predictions */
00164 char predfile[200];
    /** fixed 200-byte character buffer; presumably a file name for alpha values */
00168 char alphafile[200];
00169
00170
    /* presumably numerical tolerances -- TODO confirm */
00172 float64_t epsilon_const;
00174 float64_t epsilon_a;
00176 float64_t opt_precision;
00177
00178
    /* presumably parameters for stepping/scaling C and for unlabeled
     * examples -- TODO confirm */
00180 int32_t svm_c_steps;
00182 float64_t svm_c_factor;
00184 float64_t svm_costratio_unlab;
00186 float64_t svm_unlabbound;
    /** pointer to floating point data; presumably per-example cost values */
00188 float64_t *svm_cost;
00189 };
00190
/* Set of int32_t counters, one per named phase; passed to
 * CSVMLight::optimize_to_convergence() as `TIMING *timing_profile`.
 * Presumably accumulated run time per phase -- units and update policy are
 * not visible from this header (TODO confirm). */
00192 struct TIMING {
    /** presumably time spent in kernel evaluation */
00194 int32_t time_kernel;
    /** presumably time spent in the optimizer */
00196 int32_t time_opti;
    /** presumably time spent shrinking */
00198 int32_t time_shrink;
    /** presumably time spent updating */
00200 int32_t time_update;
    /** presumably time spent computing the model */
00202 int32_t time_model;
    /** presumably time spent in optimality checks */
00204 int32_t time_check;
    /** presumably time spent selecting the working set */
00206 int32_t time_select;
00207 };
00208
00209
/* Bookkeeping handed to init_shrink_state(), shrink_state_cleanup(),
 * shrink_problem() and reactivate_inactive_examples() of CSVMLight.
 * Member meanings are inferred from the identifiers only -- TODO confirm.
 * Which function allocates/frees the arrays is not visible from this header
 * (init_shrink_state / shrink_state_cleanup look like the obvious pair). */
00211 struct SHRINK_STATE
00212 {
    /** integer array; presumably a per-example "is active" marker */
00214 int32_t *active;
    /** integer array; presumably the iteration at which each example became inactive */
00216 int32_t *inactive_since;
    /** presumably the number of deactivated examples */
00218 int32_t deactnum;
    /** array of arrays; presumably snapshots of the alpha vector */
00220 float64_t **a_history;
    /** presumably the capacity of a_history */
00222 int32_t maxhistory;
    /** presumably the last stored alpha values */
00224 float64_t *last_a;
    /** presumably the last stored linear-component values */
00226 float64_t *last_lin;
00227 };
00228 #endif // DOXYGEN_SHOULD_SKIP_THIS
00229
/* SVM trainer class, publicly derived from CSVM. Only compiled when
 * USE_SVMLIGHT is defined (see the enclosing #ifdef).
 *
 * Apart from three trivial inline members, this header contains declarations
 * only; the descriptions below are therefore derived from names and
 * signatures and are hedged accordingly -- TODO confirm against the
 * implementation file. Pointer parameters such as `a`, `lin`, `c`, `label`
 * and `docs` are presumably arrays with one entry per training example
 * (`totdoc` entries) -- verify against callers.
 *
 * NOTE(review): the class stores raw pointers (model, learn_parm, primal,
 * dual, W) and declares no copy constructor / assignment operator here;
 * whether copies are safe depends on the base class and the destructor,
 * neither of which is visible in this header. */
00231 class CSVMLight : public CSVM
00232 {
00233 public:
    /** Default constructor. */
00235 CSVMLight();
00236
    /** Constructor taking the training ingredients.
     * @param C presumably the regularization constant -- TODO confirm
     * @param k kernel object
     * @param lab labels object
     */
00243 CSVMLight(float64_t C, CKernel* k, CLabels* lab);
    /** Virtual destructor. */
00244 virtual ~CSVMLight();
00245
    /** Presumably sets the members to their initial values; called from where
     * is not visible here. */
00247 void init();
00248
    /** @return the classifier type constant CT_LIGHT */
00253 virtual inline EClassifierType get_classifier_type() { return CT_LIGHT; }
00254
    /** @return a run time value as int32_t (units not visible from this header) */
00259 int32_t get_runtime();
00260
00261
    /** Presumably the top-level learning routine. */
00263 void svm_learn();
00264
    /** Presumably runs the decomposition loop until the termination
     * criterion holds.
     * @param shrink_state shrinking bookkeeping
     * @param timing_profile per-phase timing counters
     * @param maxdiff in/out pointer; presumably the largest violation found
     * @return int32_t; presumably the number of iterations -- TODO confirm
     */
00281 int32_t optimize_to_convergence(
00282 int32_t* docs, int32_t* label, int32_t totdoc, SHRINK_STATE *shrink_state,
00283 int32_t *inconsistent, float64_t *a, float64_t *lin, float64_t *c,
00284 TIMING *timing_profile, float64_t *maxdiff, int32_t heldout,
00285 int32_t retrain);
00286
    /** Presumably evaluates the objective for the given coefficients.
     * Virtual, so subclasses may override it.
     * @return the objective value as float64_t
     */
00297 virtual float64_t compute_objective_function(
00298 float64_t *a, float64_t *lin, float64_t *c, float64_t* eps, int32_t *label,
00299 int32_t totdoc);
00300
    /** Presumably resets an index array (the format of `index` is not
     * visible here). */
00305 void clear_index(int32_t *index);
00306
    /** Presumably appends `elem` to an index array. */
00312 void add_to_index(int32_t *index, int32_t elem);
00313
    /** Presumably builds `index` from the flag array `binfeature` of length
     * `range`.
     * @return int32_t; presumably the number of entries written -- TODO confirm
     */
00321 int32_t compute_index(int32_t *binfeature, int32_t range, int32_t *index);
00322
    /** Presumably sets up and solves one QP subproblem over the current
     * working set.
     * @param qp storage describing the subproblem
     * @param epsilon_crit_target in/out pointer; exact role not visible here
     */
00341 void optimize_svm(
00342 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const,
00343 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, int32_t totdoc,
00344 int32_t *working2dnum, int32_t varnum, float64_t *a, float64_t *lin,
00345 float64_t *c, float64_t *aicache, QP *qp, float64_t *epsilon_crit_target);
00346
    /** Presumably fills `qp` with the matrices/vectors of the subproblem. */
00364 void compute_matrices_for_optimization(
00365 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const,
00366 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, int32_t *key,
00367 float64_t *a, float64_t *lin, float64_t *c, int32_t varnum, int32_t totdoc,
00368 float64_t *aicache, QP *qp);
00369
    /** Same parameter list as compute_matrices_for_optimization(); the name
     * suggests a multi-threaded variant -- TODO confirm. */
00387 void compute_matrices_for_optimization_parallel(
00388 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const,
00389 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, int32_t *key,
00390 float64_t *a, float64_t *lin, float64_t *c, int32_t varnum, int32_t totdoc,
00391 float64_t *aicache, QP *qp);
00392
    /** Presumably updates the stored model from the current and previous
     * coefficients (`a`, `a_old`).
     * @return int32_t; meaning not visible from this header
     */
00405 int32_t calculate_svm_model(
00406 int32_t* docs, int32_t *label,float64_t *lin, float64_t *a,
00407 float64_t* a_old, float64_t *c, int32_t *working2dnum, int32_t *active2dnum);
00408
    /** Presumably tests the optimality conditions for the active examples.
     * @param maxdiff out pointer; presumably the largest violation
     * @param misclassified out pointer; presumably a count
     * @return int32_t; presumably non-zero while a retrain is needed -- TODO confirm
     */
00425 int32_t check_optimality(
00426 int32_t *label, float64_t *a, float64_t* lin, float64_t *c, int32_t totdoc,
00427 float64_t *maxdiff, float64_t epsilon_crit_org, int32_t *misclassified,
00428 int32_t *inconsistent,int32_t* active2dnum, int32_t *last_suboptimal_at,
00429 int32_t iteration);
00430
    /** Presumably brings `lin` up to date after the coefficients changed
     * from `a_old` to `a`. Virtual, so subclasses may override it. */
00444 virtual void update_linear_component(
00445 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a,
00446 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin,
00447 float64_t *aicache, float64_t* c);
00448
    /** Static helper with a `void* (*)(void*)` signature; presumably a thread
     * entry point for update_linear_component_mkl_linadd() -- TODO confirm.
     * The layout of the argument behind `p` is not visible here. */
00453 static void* update_linear_component_mkl_linadd_helper(void* p);
00454
    /** Presumably the multiple-kernel-learning variant of
     * update_linear_component() (note: no `c` parameter). */
00467 void update_linear_component_mkl(
00468 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a,
00469 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin,
00470 float64_t *aicache);
00471
    /** Same parameter list as update_linear_component_mkl(); the name suggests
     * it uses the kernel's "linadd" optimization -- TODO confirm. */
00484 void update_linear_component_mkl_linadd(
00485 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a,
00486 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin,
00487 float64_t *aicache);
00488
    /** Presumably invokes a user-supplied MKL callback with the current state. */
00489 void call_mkl_callback(float64_t* a, int32_t* label, float64_t* lin);
00490
    /** Presumably selects the next working set based on gradient information.
     * @return int32_t; presumably the number of variables selected -- TODO confirm
     */
00509 int32_t select_next_qp_subproblem_grad(
00510 int32_t *label, float64_t *a, float64_t* lin, float64_t* c, int32_t totdoc,
00511 int32_t qp_size, int32_t *inconsistent, int32_t* active2dnum,
00512 int32_t* working2dnum, float64_t *selcrit, int32_t *select,
00513 int32_t cache_only, int32_t *key, int32_t *chosen);
00514
    /** Presumably selects the next working set (pseudo-)randomly.
     * @return int32_t; presumably the number of variables selected -- TODO confirm
     */
00533 int32_t select_next_qp_subproblem_rand(
00534 int32_t* label, float64_t *a, float64_t *lin, float64_t *c,
00535 int32_t totdoc, int32_t qp_size, int32_t *inconsistent,
00536 int32_t *active2dnum, int32_t *working2dnum, float64_t *selcrit,
00537 int32_t *select, int32_t *key, int32_t *chosen, int32_t iteration);
00538
    /** Presumably writes into `select` the indices of the `n` largest of the
     * `range` values in `selcrit` -- TODO confirm ordering and tie handling. */
00546 void select_top_n(
00547 float64_t *selcrit, int32_t range, int32_t *select, int32_t n);
00548
    /** Presumably prepares `shrink_state` for `totdoc` examples and a history
     * of `maxhistory` entries. */
00555 void init_shrink_state(
00556 SHRINK_STATE *shrink_state, int32_t totdoc, int32_t maxhistory);
00557
    /** Presumably releases what init_shrink_state() set up. */
00562 void shrink_state_cleanup(SHRINK_STATE *shrink_state);
00563
    /** Presumably removes examples from the active set.
     * NOTE(review): `label` is declared as `int*` here while every other
     * declaration in this class uses `int32_t*`; harmless only where int32_t
     * is an alias of int. Changing it requires touching the definition too.
     * @return int32_t; meaning not visible from this header
     */
00579 int32_t shrink_problem(
00580 SHRINK_STATE *shrink_state, int32_t *active2dnum,
00581 int32_t *last_suboptimal_at, int32_t iteration, int32_t totdoc,
00582 int32_t minshrink, float64_t *a, int32_t *inconsistent, float64_t* c,
00583 float64_t* lin, int* label);
00584
    /** Presumably puts previously shrunk examples back into the active set.
     * Virtual, so subclasses may override it. */
00599 virtual void reactivate_inactive_examples(
00600 int32_t *label,float64_t *a,SHRINK_STATE *shrink_state, float64_t *lin,
00601 float64_t *c, int32_t totdoc,int32_t iteration, int32_t *inconsistent,
00602 int32_t *docs,float64_t *aicache, float64_t* maxdiff);
00603
00604 protected:
    /** Evaluate the kernel for the index pair (i, j) by forwarding to
     * kernel->kernel(i, j). `kernel` is not declared in this class, so it is
     * presumably inherited from a base class. Virtual, so subclasses may
     * substitute a different computation.
     * @return the value returned by kernel->kernel(i, j)
     */
00611 inline virtual float64_t compute_kernel(int32_t i, int32_t j)
00612 {
00613 return kernel->kernel(i, j);
00614 }
00615
    /* The four static members below share the `void* (*)(void*)` signature;
     * presumably thread entry points for the correspondingly named member
     * functions -- TODO confirm. The argument layout behind `p` is not
     * visible from this header. */
00620 static void* compute_kernel_helper(void* p);
00621
00626 static void* update_linear_component_linadd_helper(void* p);
00627
00632 static void* reactivate_inactive_examples_vanilla_helper(void* p);
00633
00638 static void* reactivate_inactive_examples_linadd_helper(void* p);
00639
    /** @return the string literal "SVMLight" */
00641 inline virtual const char* get_name() const { return "SVMLight"; }
00642
00643
    /** Presumably solves the subproblem described by `qp`.
     * `svm_maxqpsize` is taken by non-const reference, so the callee is able
     * to modify the caller's value.
     * @return pointer to float64_t data; presumably the solution vector --
     *         ownership not visible here, TODO confirm
     */
00644 float64_t *optimize_qp( QP *qp,float64_t *epsilon_crit, int32_t nx,
00645 float64_t *threshold, int32_t& svm_maxqpsize);
00646
    /** Training entry point.
     * @param data optional features object; defaults to NULL
     * @return bool; presumably true on success -- TODO confirm
     */
00655 virtual bool train_machine(CFeatures* data=NULL);
00656
    /* NOTE(review): this second `protected:` label is redundant; the section
     * above is already protected. */
00657 protected:
    /** pointer to the model structure */
00659 MODEL* model;
    /** pointer to the learning parameters */
00661 LEARN_PARM* learn_parm;
    /** presumably the verbosity level */
00663 int32_t verbosity;
00664
    /* presumably state shared with the QP optimizer -- TODO confirm */
00666 float64_t init_margin;
00668 int32_t init_iter;
00670 int32_t precision_violations;
00672 float64_t model_b;
00674 float64_t opt_precision;
    /* floating point arrays; presumably the primal and dual variables of the
     * QP optimizer -- TODO confirm */
00676 float64_t* primal;
00678 float64_t* dual;
00679
00680
00681
    /** floating point array; meaning and dimensions not derivable from this
     * header -- TODO document */
00685 float64_t* W;
    /** integer counter; what it counts is not visible here */
00687 int32_t count;
    /** presumably the current maximum violation -- TODO confirm */
00689 float64_t mymaxdiff;
    /** presumably whether the kernel cache is used */
00691 bool use_kernel_cache;
    /** presumably whether the MKL step has converged */
00693 bool mkl_converged;
00694 };
00695 }
00696 #endif //USE_SVMLIGHT
00697 #endif //_SVMLight_H___