OnlineSVMSGD.h

Go to the documentation of this file.
00001 #ifndef _ONLINESVMSGD_H___
00002 #define _ONLINESVMSGD_H___
00003 /*
00004    SVM with stochastic gradient
00005    Copyright (C) 2007- Leon Bottou
00006 
00007    This program is free software; you can redistribute it and/or
00008    modify it under the terms of the GNU Lesser General Public
00009    License as published by the Free Software Foundation; either
00010    version 2.1 of the License, or (at your option) any later version.
00011 
00012    This program is distributed in the hope that it will be useful,
00013    but WITHOUT ANY WARRANTY; without even the implied warranty of
00014    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00015    GNU General Public License for more details.
00016 
00017    You should have received a copy of the GNU General Public License
00018    along with this program; if not, write to the Free Software
00019    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
00020 
00021    Shogun adjustments (w) 2008 Soeren Sonnenburg
00022 */
00023 
00024 #include <shogun/lib/common.h>
00025 #include <shogun/labels/Labels.h>
00026 #include <shogun/machine/OnlineLinearMachine.h>
00027 #include <shogun/features/streaming/StreamingDotFeatures.h>
00028 #include <shogun/loss/LossFunction.h>
00029 
00030 namespace shogun
00031 {
00033 class COnlineSVMSGD : public COnlineLinearMachine
00034 {
00035     public:
00037         MACHINE_PROBLEM_TYPE(PT_BINARY);
00038 
00040         COnlineSVMSGD();
00041 
00046         COnlineSVMSGD(float64_t C);
00047 
00053         COnlineSVMSGD(float64_t C, CStreamingDotFeatures* traindat);
00054 
00055         virtual ~COnlineSVMSGD();
00056 
00061         virtual EMachineType get_classifier_type() { return CT_SVMSGD; }
00062 
00071         virtual bool train(CFeatures* data=NULL);
00072 
00079         inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
00080 
00085         inline float64_t get_C1() { return C1; }
00086 
00091         inline float64_t get_C2() { return C2; }
00092 
00097         inline void set_epochs(int32_t e) { epochs=e; }
00098 
00103         inline int32_t get_epochs() { return epochs; }
00104 
00109         inline void set_lambda(float64_t l) { lambda=l; }
00110 
00115         inline float64_t get_lambda() { return lambda; }
00116 
00121         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00122 
00127         inline bool get_bias_enabled() { return use_bias; }
00128 
00133         inline void set_regularized_bias_enabled(bool enable_bias) { use_regularized_bias=enable_bias; }
00134 
00139         inline bool get_regularized_bias_enabled() { return use_regularized_bias; }
00140 
00145         void set_loss_function(CLossFunction* loss_func);
00146 
00151         inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; }
00152 
00154         inline const char* get_name() const { return "OnlineSVMSGD"; }
00155 
00156     protected:
00162         void calibrate(int32_t max_vec_num=1000);
00163 
00164     private:
00165         void init();
00166 
00167     private:
00168         float64_t t;
00169         float64_t lambda;
00170         float64_t C1;
00171         float64_t C2;
00172         float64_t wscale;
00173         float64_t bscale;
00174         int32_t epochs;
00175         int32_t skip;
00176         int32_t count;
00177 
00178         bool use_bias;
00179         bool use_regularized_bias;
00180 
00181         CLossFunction* loss;
00182 };
00183 }
00184 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation