LinearMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/features/Labels.h>
00016 #include <shogun/features/DotFeatures.h>
00017 #include <shogun/machine/Machine.h>
00018 
00019 #include <stdio.h>
00020 
00021 namespace shogun
00022 {
00023     class CDotFeatures;
00024     class CMachine;
00025     class CLabels;
00026 
00061 class CLinearMachine : public CMachine
00062 {
00063     public:
00065         CLinearMachine();
00066         virtual ~CLinearMachine();
00067 
00073         inline void get_w(float64_t*& dst_w, int32_t& dst_dims)
00074         {
00075             ASSERT(w && w_dim>0);
00076             dst_w=w;
00077             dst_dims=w_dim;
00078         }
00079 
00084         inline SGVector<float64_t> get_w()
00085         {
00086             return SGVector<float64_t>(w, w_dim, false);
00087         }
00088 
00093         inline void set_w(SGVector<float64_t> src_w)
00094         {
00095             SG_FREE(w);
00096             w=src_w.vector;
00097             w_dim=src_w.vlen;
00098         }
00099 
00104         inline void set_bias(float64_t b)
00105         {
00106             bias=b;
00107         }
00108 
00113         inline float64_t get_bias()
00114         {
00115             return bias;
00116         }
00117 
00123         virtual bool load(FILE* srcfile);
00124 
00130         virtual bool save(FILE* dstfile);
00131 
00136         virtual inline void set_features(CDotFeatures* feat)
00137         {
00138             SG_UNREF(features);
00139             SG_REF(feat);
00140             features=feat;
00141         }
00142 
00147         virtual CLabels* apply();
00148 
00154         virtual CLabels* apply(CFeatures* data);
00155 
00157         virtual float64_t apply(int32_t vec_idx)
00158         {
00159             return features->dense_dot(vec_idx, w, w_dim) + bias;
00160         }
00161 
00166         virtual CDotFeatures* get_features() { SG_REF(features); return features; }
00167 
00173         virtual const char* get_name() const { return "LinearMachine"; }
00174 
00175     protected:
00180         virtual void store_model_features() {}
00181 
00182     protected:
00184         int32_t w_dim;
00186         float64_t* w;
00188         float64_t bias;
00190         CDotFeatures* features;
00191 };
00192 }
00193 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation