Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _ONLINELINEARCLASSIFIER_H__
00012 #define _ONLINELINEARCLASSIFIER_H__
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/RegressionLabels.h>
00017 #include <shogun/features/streaming/StreamingDotFeatures.h>
00018 #include <shogun/machine/Machine.h>
00019
00020 #include <stdio.h>
00021
00022 namespace shogun
00023 {
00050 class COnlineLinearMachine : public CMachine
00051 {
00052 public:
00054 COnlineLinearMachine();
00055 virtual ~COnlineLinearMachine();
00056
00062 virtual void get_w(float32_t*& dst_w, int32_t& dst_dims)
00063 {
00064 ASSERT(w && w_dim>0);
00065 dst_w=w;
00066 dst_dims=w_dim;
00067 }
00068
00075 virtual void get_w(float64_t*& dst_w, int32_t& dst_dims)
00076 {
00077 ASSERT(w && w_dim>0);
00078 dst_w=SG_MALLOC(float64_t, w_dim);
00079 for (int32_t i=0; i<w_dim; i++)
00080 dst_w[i]=w[i];
00081 dst_dims=w_dim;
00082 }
00083
00088 virtual SGVector<float32_t> get_w()
00089 {
00090 return SGVector<float32_t>(w, w_dim);
00091 }
00092
00098 virtual void set_w(float32_t* src_w, int32_t src_w_dim)
00099 {
00100 SG_FREE(w);
00101 w=SG_MALLOC(float32_t, src_w_dim);
00102 memcpy(w, src_w, size_t(src_w_dim)*sizeof(float32_t));
00103 w_dim=src_w_dim;
00104 }
00105
00112 virtual void set_w(float64_t* src_w, int32_t src_w_dim)
00113 {
00114 SG_FREE(w);
00115 w=SG_MALLOC(float32_t, src_w_dim);
00116 for (int32_t i=0; i<src_w_dim; i++)
00117 w[i] = src_w[i];
00118 w_dim=src_w_dim;
00119 }
00120
00125 virtual void set_bias(float32_t b)
00126 {
00127 bias=b;
00128 }
00129
00134 virtual float32_t get_bias()
00135 {
00136 return bias;
00137 }
00138
00143 virtual void set_features(CStreamingDotFeatures* feat)
00144 {
00145 if (features)
00146 SG_UNREF(features);
00147 SG_REF(feat);
00148 features=feat;
00149 }
00150
00157 virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
00158
00165 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00166
00168 virtual float64_t apply_one(int32_t vec_idx)
00169 {
00170 SG_NOTIMPLEMENTED;
00171 return CMath::INFTY;
00172 }
00173
00182 virtual float32_t apply_one(float32_t* vec, int32_t len);
00183
00189 virtual float32_t apply_to_current_example();
00190
00195 virtual CStreamingDotFeatures* get_features() { SG_REF(features); return features; }
00196
00202 virtual const char* get_name() const { return "OnlineLinearMachine"; }
00203
00207 virtual void start_train() { }
00208
00212 virtual void stop_train() { }
00213
00223 virtual void train_example(CStreamingDotFeatures *feature, float64_t label) { SG_NOTIMPLEMENTED; }
00224
00225 protected:
00234 virtual bool train_machine(CFeatures* data=NULL);
00235
00241 SGVector<float64_t> apply_get_outputs(CFeatures* data);
00242
00244 virtual bool train_require_labels() const { return false; }
00245
00246 protected:
00248 int32_t w_dim;
00250 float32_t* w;
00252 float32_t bias;
00254 CStreamingDotFeatures* features;
00255 };
00256 }
00257 #endif