00001 #ifndef _ONLINESVMSGD_H___ 00002 #define _ONLINESVMSGD_H___ 00003 /* 00004 SVM with stochastic gradient 00005 Copyright (C) 2007- Leon Bottou 00006 00007 This program is free software; you can redistribute it and/or 00008 modify it under the terms of the GNU Lesser General Public 00009 License as published by the Free Software Foundation; either 00010 version 2.1 of the License, or (at your option) any later version. 00011 00012 This program is distributed in the hope that it will be useful, 00013 but WITHOUT ANY WARRANTY; without even the implied warranty of 00014 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00015 GNU General Public License for more details. 00016 00017 You should have received a copy of the GNU General Public License 00018 along with this program; if not, write to the Free Software 00019 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00020 00021 Shogun adjustments (w) 2008 Soeren Sonnenburg 00022 */ 00023 00024 #include <shogun/lib/common.h> 00025 #include <shogun/features/Labels.h> 00026 #include <shogun/machine/OnlineLinearMachine.h> 00027 #include <shogun/features/StreamingDotFeatures.h> 00028 #include <shogun/loss/LossFunction.h> 00029 00030 namespace shogun 00031 { 00033 class COnlineSVMSGD : public COnlineLinearMachine 00034 { 00035 public: 00037 COnlineSVMSGD(); 00038 00043 COnlineSVMSGD(float64_t C); 00044 00050 COnlineSVMSGD(float64_t C, CStreamingDotFeatures* traindat); 00051 00052 virtual ~COnlineSVMSGD(); 00053 00058 virtual inline EClassifierType get_classifier_type() { return CT_SVMSGD; } 00059 00068 virtual bool train(CFeatures* data=NULL); 00069 00076 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00077 00082 inline float64_t get_C1() { return C1; } 00083 00088 inline float64_t get_C2() { return C2; } 00089 00094 inline void set_epochs(int32_t e) { epochs=e; } 00095 00100 inline int32_t get_epochs() { return epochs; } 00101 00106 inline void set_lambda(float64_t l) { lambda=l; } 00107 00112 inline float64_t get_lambda() { return lambda; } 00113 00118 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00119 00124 inline bool get_bias_enabled() { return use_bias; } 00125 00130 inline void set_regularized_bias_enabled(bool enable_bias) { use_regularized_bias=enable_bias; } 00131 00136 inline bool get_regularized_bias_enabled() { return use_regularized_bias; } 00137 00142 void set_loss_function(CLossFunction* loss_func); 00143 00148 inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; } 00149 00151 inline virtual const char* get_name() const { return "OnlineSVMSGD"; } 00152 00153 protected: 00159 void calibrate(int32_t max_vec_num=1000); 00160 00161 private: 00162 void init(); 00163 00164 private: 00165 float64_t t; 00166 float64_t lambda; 00167 float64_t C1; 00168 float64_t C2; 00169 float64_t wscale; 00170 float64_t bscale; 00171 int32_t epochs; 00172 int32_t skip; 00173 int32_t count; 00174 00175 bool use_bias; 00176 bool use_regularized_bias; 00177 00178 CLossFunction* loss; 00179 }; 00180 } 00181 #endif