Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LIBLINEAR_H___
00013 #define _LIBLINEAR_H___
00014
00015 #include <shogun/lib/config.h>
00016
00017 #include <shogun/lib/common.h>
00018 #include <shogun/base/Parameter.h>
00019 #include <shogun/machine/LinearMachine.h>
00020 #include <shogun/optimization/liblinear/shogun_liblinear.h>
00021
00022 namespace shogun
00023 {
/** @brief Solver back-ends selectable for CLibLinear.
 *
 * Enumerator names follow the LIBLINEAR solver naming scheme
 * (L1R/L2R = L1/L2 regularization, LR = logistic regression,
 * SVC = support vector classification, L1LOSS/L2LOSS = hinge/squared-hinge
 * loss, DUAL = solved in the dual). The per-value descriptions below are
 * inferred from that scheme — confirm against the solver implementations.
 *
 * Values are pinned explicitly. They are identical to the implicit 0..6
 * numbering, so reordering the list later cannot silently renumber them.
 */
enum LIBLINEAR_SOLVER_TYPE
{
	/// L2-regularized logistic regression
	L2R_LR = 0,
	/// L2-regularized SVM with L2 (squared hinge) loss, dual solver
	L2R_L2LOSS_SVC_DUAL = 1,
	/// L2-regularized SVM with L2 (squared hinge) loss
	L2R_L2LOSS_SVC = 2,
	/// L2-regularized SVM with L1 (hinge) loss, dual solver
	L2R_L1LOSS_SVC_DUAL = 3,
	/// L1-regularized SVM with L2 (squared hinge) loss
	L1R_L2LOSS_SVC = 4,
	/// L1-regularized logistic regression
	L1R_LR = 5,
	/// L2-regularized logistic regression, dual solver
	L2R_LR_DUAL = 6
};
00043
00045 class CLibLinear : public CLinearMachine
00046 {
00047 public:
00048 MACHINE_PROBLEM_TYPE(PT_BINARY);
00049
00051 CLibLinear();
00052
00057 CLibLinear(LIBLINEAR_SOLVER_TYPE liblinear_solver_type);
00058
00065 CLibLinear(
00066 float64_t C, CDotFeatures* traindat,
00067 CLabels* trainlab);
00068
00070 virtual ~CLibLinear();
00071
00072 inline LIBLINEAR_SOLVER_TYPE get_liblinear_solver_type()
00073 {
00074 return liblinear_solver_type;
00075 }
00076
00077 inline void set_liblinear_solver_type(LIBLINEAR_SOLVER_TYPE st)
00078 {
00079 liblinear_solver_type=st;
00080 }
00081
00086 virtual EMachineType get_classifier_type() { return CT_LIBLINEAR; }
00087
00093 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
00094
00099 inline float64_t get_C1() { return C1; }
00100
00105 inline float64_t get_C2() { return C2; }
00106
00111 inline void set_epsilon(float64_t eps) { epsilon=eps; }
00112
00117 inline float64_t get_epsilon() { return epsilon; }
00118
00123 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00124
00129 inline bool get_bias_enabled() { return use_bias; }
00130
00132 virtual const char* get_name() const { return "LibLinear"; }
00133
00135 inline int32_t get_max_iterations()
00136 {
00137 return max_iterations;
00138 }
00139
00141 inline void set_max_iterations(int32_t max_iter=1000)
00142 {
00143 max_iterations=max_iter;
00144 }
00145
00147 void set_linear_term(const SGVector<float64_t> linear_term);
00148
00150 SGVector<float64_t> get_linear_term();
00151
00153 void init_linear_term();
00154
00155 protected:
00164 virtual bool train_machine(CFeatures* data=NULL);
00165
00166 private:
00168 void init();
00169
00170 void train_one(const problem *prob, const parameter *param, double Cp, double Cn);
00171 void solve_l2r_l1l2_svc(
00172 const problem *prob, double eps, double Cp, double Cn, LIBLINEAR_SOLVER_TYPE st);
00173
00174 void solve_l1r_l2_svc(problem *prob_col, double eps, double Cp, double Cn);
00175 void solve_l1r_lr(const problem *prob_col, double eps, double Cp, double Cn);
00176 void solve_l2r_lr_dual(const problem *prob, double eps, double Cp, double Cn);
00177
00178
00179 protected:
00181 float64_t C1;
00183 float64_t C2;
00185 bool use_bias;
00187 float64_t epsilon;
00189 int32_t max_iterations;
00190
00192 SGVector<float64_t> m_linear_term;
00193
00195 LIBLINEAR_SOLVER_TYPE liblinear_solver_type;
00196 };
00197
00198 }
00199
00200 #endif //_LIBLINEAR_H___