OnlineLinearMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _ONLINELINEARCLASSIFIER_H__
00012 #define _ONLINELINEARCLASSIFIER_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/RegressionLabels.h>
00017 #include <shogun/features/streaming/StreamingDotFeatures.h>
00018 #include <shogun/machine/Machine.h>
00019 
00020 #include <stdio.h>
00021 
00022 namespace shogun
00023 {
00050 class COnlineLinearMachine : public CMachine
00051 {
00052     public:
00054         COnlineLinearMachine();
00055         virtual ~COnlineLinearMachine();
00056 
00062         virtual void get_w(float32_t*& dst_w, int32_t& dst_dims)
00063         {
00064             ASSERT(w && w_dim>0);
00065             dst_w=w;
00066             dst_dims=w_dim;
00067         }
00068 
00075         virtual void get_w(float64_t*& dst_w, int32_t& dst_dims)
00076         {
00077             ASSERT(w && w_dim>0);
00078             dst_w=SG_MALLOC(float64_t, w_dim);
00079             for (int32_t i=0; i<w_dim; i++)
00080                 dst_w[i]=w[i];
00081             dst_dims=w_dim;
00082         }
00083 
00088         virtual SGVector<float32_t> get_w()
00089         {
00090             return SGVector<float32_t>(w, w_dim);
00091         }
00092 
00098         virtual void set_w(float32_t* src_w, int32_t src_w_dim)
00099         {
00100             SG_FREE(w);
00101             w=SG_MALLOC(float32_t, src_w_dim);
00102             memcpy(w, src_w, size_t(src_w_dim)*sizeof(float32_t));
00103             w_dim=src_w_dim;
00104         }
00105 
00112         virtual void set_w(float64_t* src_w, int32_t src_w_dim)
00113         {
00114             SG_FREE(w);
00115             w=SG_MALLOC(float32_t, src_w_dim);
00116             for (int32_t i=0; i<src_w_dim; i++)
00117                 w[i] = src_w[i];
00118             w_dim=src_w_dim;
00119         }
00120 
00125         virtual void set_bias(float32_t b)
00126         {
00127             bias=b;
00128         }
00129 
00134         virtual float32_t get_bias()
00135         {
00136             return bias;
00137         }
00138 
00143         virtual void set_features(CStreamingDotFeatures* feat)
00144         {
00145             if (features)
00146                 SG_UNREF(features);
00147             SG_REF(feat);
00148             features=feat;
00149         }
00150 
00157         virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
00158         
00165         virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00166 
00168         virtual float64_t apply_one(int32_t vec_idx)
00169         {
00170             SG_NOTIMPLEMENTED;
00171             return CMath::INFTY;
00172         }
00173 
00182         virtual float32_t apply_one(float32_t* vec, int32_t len);
00183 
00189         virtual float32_t apply_to_current_example();
00190 
00195         virtual CStreamingDotFeatures* get_features() { SG_REF(features); return features; }
00196 
00202         virtual const char* get_name() const { return "OnlineLinearMachine"; }
00203 
00207         virtual void start_train() { }
00208 
00212         virtual void stop_train() { }
00213 
00223         virtual void train_example(CStreamingDotFeatures *feature, float64_t label) { SG_NOTIMPLEMENTED; }
00224 
00225     protected:
00234         virtual bool train_machine(CFeatures* data=NULL);
00235 
00241         SGVector<float64_t> apply_get_outputs(CFeatures* data);
00242 
00244         virtual bool train_require_labels() const { return false; }
00245 
00246     protected:
00248         int32_t w_dim;
00250         float32_t* w;
00252         float32_t bias;
00254         CStreamingDotFeatures* features;
00255 };
00256 }
00257 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation