00001 #ifndef _SVMSGD_H___ 00002 #define _SVMSGD_H___ 00003 00004 /* 00005 SVM with stochastic gradient 00006 Copyright (C) 2007- Leon Bottou 00007 00008 This program is free software; you can redistribute it and/or 00009 modify it under the terms of the GNU Lesser General Public 00010 License as published by the Free Software Foundation; either 00011 version 2.1 of the License, or (at your option) any later version. 00012 00013 This program is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00016 GNU General Public License for more details. 00017 00018 You should have received a copy of the GNU General Public License 00019 along with this program; if not, write to the Free Software 00020 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00021 00022 Shogun adjustments (w) 2008 Soeren Sonnenburg 00023 */ 00024 00025 #include <shogun/lib/common.h> 00026 #include <shogun/machine/LinearMachine.h> 00027 #include <shogun/features/DotFeatures.h> 00028 #include <shogun/labels/Labels.h> 00029 #include <shogun/loss/LossFunction.h> 00030 00031 namespace shogun 00032 { 00034 class CSVMSGD : public CLinearMachine 00035 { 00036 public: 00037 00039 MACHINE_PROBLEM_TYPE(PT_BINARY); 00040 00042 CSVMSGD(); 00043 00048 CSVMSGD(float64_t C); 00049 00056 CSVMSGD( 00057 float64_t C, CDotFeatures* traindat, 00058 CLabels* trainlab); 00059 00060 virtual ~CSVMSGD(); 00061 00066 virtual EMachineType get_classifier_type() { return CT_SVMSGD; } 00067 00074 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00075 00080 inline float64_t get_C1() { return C1; } 00081 00086 inline float64_t get_C2() { return C2; } 00087 00092 inline void set_epochs(int32_t e) { epochs=e; } 00093 00098 inline int32_t get_epochs() { return epochs; } 00099 00104 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00105 00110 inline bool get_bias_enabled() { return use_bias; } 00111 00116 inline void set_regularized_bias_enabled(bool enable_bias) { use_regularized_bias=enable_bias; } 00117 00122 inline bool get_regularized_bias_enabled() { return use_regularized_bias; } 00123 00128 void set_loss_function(CLossFunction* loss_func); 00129 00134 inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; } 00135 00137 virtual const char* get_name() const { return "SVMSGD"; } 00138 00139 protected: 00141 void calibrate(); 00142 00151 virtual bool train_machine(CFeatures* data=NULL); 00152 00153 private: 00154 void init(); 00155 00156 private: 00157 float64_t t; 00158 float64_t C1; 00159 float64_t C2; 00160 float64_t wscale; 00161 float64_t bscale; 00162 int32_t epochs; 00163 int32_t skip; 00164 int32_t count; 00165 00166 bool use_bias; 00167 bool use_regularized_bias; 00168 00169 CLossFunction* loss; 00170 }; 00171 } 00172 #endif