00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #ifndef _MULTICLASSLIBLINEAR_H___ 00012 #define _MULTICLASSLIBLINEAR_H___ 00013 #include <shogun/lib/config.h> 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/DotFeatures.h> 00016 #include <shogun/machine/LinearMulticlassMachine.h> 00017 #include <shogun/optimization/liblinear/shogun_liblinear.h> 00018 00019 namespace shogun 00020 { 00021 00036 class CMulticlassLibLinear : public CLinearMulticlassMachine 00037 { 00038 public: 00039 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00040 00041 00042 CMulticlassLibLinear(); 00043 00049 CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs); 00050 00052 virtual ~CMulticlassLibLinear(); 00053 00055 virtual const char* get_name() const 00056 { 00057 return "MulticlassLibLinear"; 00058 } 00059 00063 inline void set_C(float64_t C) 00064 { 00065 ASSERT(C>0); 00066 m_C = C; 00067 } 00071 inline float64_t get_C() const { return m_C; } 00072 00076 inline void set_epsilon(float64_t epsilon) 00077 { 00078 ASSERT(epsilon>0); 00079 m_epsilon = epsilon; 00080 } 00084 inline float64_t get_epsilon() const { return m_epsilon; } 00085 00089 inline void set_use_bias(bool use_bias) 00090 { 00091 m_use_bias = use_bias; 00092 } 00096 inline bool get_use_bias() const 00097 { 00098 return m_use_bias; 00099 } 00100 00104 inline void set_save_train_state(bool save_train_state) 00105 { 00106 m_save_train_state = save_train_state; 00107 } 00111 inline bool get_save_train_state() const 00112 { 00113 return m_save_train_state; 00114 } 00115 00119 inline void set_max_iter(int32_t max_iter) 00120 { 00121 ASSERT(max_iter>0); 00122 m_max_iter = max_iter; 00123 } 00127 inline int32_t get_max_iter() const { return m_max_iter; } 00128 00130 void reset_train_state() 00131 { 00132 if (m_train_state) 00133 { 00134 delete m_train_state; 00135 m_train_state = NULL; 00136 } 00137 } 00138 00142 SGVector<int32_t> get_support_vectors() const; 00143 00144 protected: 00145 00147 virtual bool train_machine(CFeatures* data = NULL); 00148 00150 virtual SGMatrix<float64_t> obtain_regularizer_matrix() const; 00151 00152 private: 00153 00155 void init_defaults(); 00156 00158 void register_parameters(); 00159 00160 protected: 00161 00163 float64_t m_C; 00164 00166 float64_t m_epsilon; 00167 00169 int32_t m_max_iter; 00170 00172 bool m_use_bias; 00173 00175 bool m_save_train_state; 00176 00178 mcsvm_state* m_train_state; 00179 }; 00180 } 00181 #endif