SGDQN.h

Go to the documentation of this file.
00001 #ifndef _SGDQN_H___
00002 #define _SGDQN_H___
00003 
00004 /*
00005    SVM with Quasi-Newton stochastic gradient
00006    Copyright (C) 2009- Antoine Bordes
00007 
00008    This program is free software; you can redistribute it and/or
00009    modify it under the terms of the GNU Lesser General Public
00010    License as published by the Free Software Foundation; either
00011    version 2.1 of the License, or (at your option) any later version.
00012 
00013    This program is distributed in the hope that it will be useful,
00014    but WITHOUT ANY WARRANTY; without even the implied warranty of
00015    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016    GNU General Public License for more details.
00017 
00018    You should have received a copy of the GNU General Public License
00019    along with this program; if not, write to the Free Software
00020    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
00021 
00022    Shogun adjustments (w) 2011 Siddharth Kherada
00023 */
00024 
00025 #include <shogun/lib/common.h>
00026 #include <shogun/machine/LinearMachine.h>
00027 #include <shogun/features/DotFeatures.h>
00028 #include <shogun/labels/Labels.h>
00029 #include <shogun/loss/LossFunction.h>
00030 
00031 namespace shogun
00032 {
00034 class CSGDQN : public CLinearMachine
00035 {
00036     public:
00037 
00039         MACHINE_PROBLEM_TYPE(PT_BINARY);
00040 
00042         CSGDQN();
00043 
00048         CSGDQN(float64_t C);
00049 
00056         CSGDQN(
00057             float64_t C, CDotFeatures* traindat,
00058             CLabels* trainlab);
00059 
00060         virtual ~CSGDQN();
00061 
00066         virtual EMachineType get_classifier_type() { return CT_SGDQN; }
00067 
00076         virtual bool train(CFeatures* data=NULL);
00077 
00084         inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
00085 
00090         inline float64_t get_C1() { return C1; }
00091 
00096         inline float64_t get_C2() { return C2; }
00097 
00102         inline void set_epochs(int32_t e) { epochs=e; }
00103 
00108         inline int32_t get_epochs() { return epochs; }
00109 
00111         void compute_ratio(float64_t* W,float64_t* W_1,float64_t* B,float64_t* dst,int32_t dim,float64_t regularizer_lambda,float64_t loss);
00112 
00114         void combine_and_clip(float64_t* Bc,float64_t* B,int32_t dim,float64_t c1,float64_t c2,float64_t v1,float64_t v2);
00115 
00120         void set_loss_function(CLossFunction* loss_func);
00121 
00126         inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; }
00127 
00129         virtual const char* get_name() const { return "SGDQN"; }
00130 
00131     protected:
00133         void calibrate();
00134 
00135     private:
00136         void init();
00137 
00138     private:
00139         float64_t t;
00140         float64_t C1;
00141         float64_t C2;
00142         int32_t epochs;
00143         int32_t skip;
00144         int32_t count;
00145 
00146         CLossFunction* loss;
00147 };
00148 }
00149 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation