Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _ONLINELINEARCLASSIFIER_H__
00012 #define _ONLINELINEARCLASSIFIER_H__
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/features/Labels.h>
00016 #include <shogun/features/StreamingDotFeatures.h>
00017 #include <shogun/machine/Machine.h>
00018
00019 #include <stdio.h>
00020
00021 namespace shogun
00022 {
00049 class COnlineLinearMachine : public CMachine
00050 {
00051 public:
00053 COnlineLinearMachine();
00054 virtual ~COnlineLinearMachine();
00055
00061 virtual inline void get_w(float32_t*& dst_w, int32_t& dst_dims)
00062 {
00063 ASSERT(w && w_dim>0);
00064 dst_w=w;
00065 dst_dims=w_dim;
00066 }
00067
00074 virtual void get_w(float64_t*& dst_w, int32_t& dst_dims)
00075 {
00076 ASSERT(w && w_dim>0);
00077 dst_w=SG_MALLOC(float64_t, w_dim);
00078 for (int32_t i=0; i<w_dim; i++)
00079 dst_w[i]=w[i];
00080 dst_dims=w_dim;
00081 }
00082
00087 virtual inline SGVector<float32_t> get_w()
00088 {
00089 return SGVector<float32_t>(w, w_dim);
00090 }
00091
00097 virtual inline void set_w(float32_t* src_w, int32_t src_w_dim)
00098 {
00099 SG_FREE(w);
00100 w=SG_MALLOC(float32_t, src_w_dim);
00101 memcpy(w, src_w, size_t(src_w_dim)*sizeof(float32_t));
00102 w_dim=src_w_dim;
00103 }
00104
00111 virtual void set_w(float64_t* src_w, int32_t src_w_dim)
00112 {
00113 SG_FREE(w);
00114 w=SG_MALLOC(float32_t, src_w_dim);
00115 for (int32_t i=0; i<src_w_dim; i++)
00116 w[i] = src_w[i];
00117 w_dim=src_w_dim;
00118 }
00119
00124 virtual inline void set_bias(float32_t b)
00125 {
00126 bias=b;
00127 }
00128
00133 virtual inline float32_t get_bias()
00134 {
00135 return bias;
00136 }
00137
00143 virtual bool load(FILE* srcfile);
00144
00150 virtual bool save(FILE* dstfile);
00151
00156 virtual inline void set_features(CStreamingDotFeatures* feat)
00157 {
00158 if (features)
00159 SG_UNREF(features);
00160 SG_REF(feat);
00161 features=feat;
00162 }
00163
00168 virtual CLabels* apply();
00169
00175 virtual CLabels* apply(CFeatures* data);
00176
00178 virtual float64_t apply(int32_t vec_idx)
00179 {
00180 SG_NOTIMPLEMENTED;
00181 return CMath::INFTY;
00182 }
00183
00192 virtual float32_t apply(float32_t* vec, int32_t len);
00193
00199 virtual float32_t apply_to_current_example();
00200
00205 virtual CStreamingDotFeatures* get_features() { SG_REF(features); return features; }
00206
00212 virtual const char* get_name() const { return "OnlineLinearMachine"; }
00213
00214 protected:
00216 int32_t w_dim;
00218 float32_t* w;
00220 float32_t bias;
00222 CStreamingDotFeatures* features;
00223 };
00224 }
00225 #endif