00001 #ifndef _ONLINESVMSGD_H___ 00002 #define _ONLINESVMSGD_H___ 00003 /* 00004 SVM with stochastic gradient 00005 Copyright (C) 2007- Leon Bottou 00006 00007 This program is free software; you can redistribute it and/or 00008 modify it under the terms of the GNU Lesser General Public 00009 License as published by the Free Software Foundation; either 00010 version 2.1 of the License, or (at your option) any later version. 00011 00012 This program is distributed in the hope that it will be useful, 00013 but WITHOUT ANY WARRANTY; without even the implied warranty of 00014 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00015 GNU General Public License for more details. 00016 00017 You should have received a copy of the GNU General Public License 00018 along with this program; if not, write to the Free Software 00019 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00020 00021 Shogun adjustments (w) 2008 Soeren Sonnenburg 00022 */ 00023 00024 #include <shogun/lib/common.h> 00025 #include <shogun/labels/Labels.h> 00026 #include <shogun/machine/OnlineLinearMachine.h> 00027 #include <shogun/features/streaming/StreamingDotFeatures.h> 00028 #include <shogun/loss/LossFunction.h> 00029 00030 namespace shogun 00031 { 00033 class COnlineSVMSGD : public COnlineLinearMachine 00034 { 00035 public: 00037 MACHINE_PROBLEM_TYPE(PT_BINARY); 00038 00040 COnlineSVMSGD(); 00041 00046 COnlineSVMSGD(float64_t C); 00047 00053 COnlineSVMSGD(float64_t C, CStreamingDotFeatures* traindat); 00054 00055 virtual ~COnlineSVMSGD(); 00056 00061 virtual EMachineType get_classifier_type() { return CT_SVMSGD; } 00062 00071 virtual bool train(CFeatures* data=NULL); 00072 00079 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00080 00085 inline float64_t get_C1() { return C1; } 00086 00091 inline float64_t get_C2() { return C2; } 00092 00097 inline void set_epochs(int32_t e) { epochs=e; } 00098 00103 inline int32_t get_epochs() { return epochs; } 00104 00109 inline void set_lambda(float64_t l) { lambda=l; } 00110 00115 inline float64_t get_lambda() { return lambda; } 00116 00121 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00122 00127 inline bool get_bias_enabled() { return use_bias; } 00128 00133 inline void set_regularized_bias_enabled(bool enable_bias) { use_regularized_bias=enable_bias; } 00134 00139 inline bool get_regularized_bias_enabled() { return use_regularized_bias; } 00140 00145 void set_loss_function(CLossFunction* loss_func); 00146 00151 inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; } 00152 00154 inline const char* get_name() const { return "OnlineSVMSGD"; } 00155 00156 protected: 00162 void calibrate(int32_t max_vec_num=1000); 00163 00164 private: 00165 void init(); 00166 00167 private: 00168 float64_t t; 00169 float64_t lambda; 00170 float64_t C1; 00171 float64_t C2; 00172 float64_t wscale; 00173 float64_t bscale; 00174 int32_t epochs; 00175 int32_t skip; 00176 int32_t count; 00177 00178 bool use_bias; 00179 bool use_regularized_bias; 00180 00181 CLossFunction* loss; 00182 }; 00183 } 00184 #endif