LinearMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/features/Labels.h>
00016 #include <shogun/features/DotFeatures.h>
00017 #include <shogun/machine/Machine.h>
00018 
00019 #include <stdio.h>
00020 
00021 namespace shogun
00022 {
00023     class CDotFeatures;
00024     class CMachine;
00025     class CLabels;
00026 
00061 class CLinearMachine : public CMachine
00062 {
00063     public:
00065         CLinearMachine();
00066         virtual ~CLinearMachine();
00067 
00073         inline void get_w(float64_t*& dst_w, int32_t& dst_dims)
00074         {
00075             ASSERT(w && w_dim>0);
00076             dst_w=w;
00077             dst_dims=w_dim;
00078         }
00079 
00084         inline SGVector<float64_t> get_w()
00085         {
00086             return SGVector<float64_t>(w, w_dim);
00087         }
00088 
00094         inline void set_w(float64_t* src_w, int32_t src_w_dim)
00095         {
00096             SG_FREE(w);
00097             w=SG_MALLOC(float64_t, src_w_dim);
00098             memcpy(w, src_w, size_t(src_w_dim)*sizeof(float64_t));
00099             w_dim=src_w_dim;
00100         }
00101 
00106         inline void set_bias(float64_t b)
00107         {
00108             bias=b;
00109         }
00110 
00115         inline float64_t get_bias()
00116         {
00117             return bias;
00118         }
00119 
00125         virtual bool load(FILE* srcfile);
00126 
00132         virtual bool save(FILE* dstfile);
00133 
00138         virtual inline void set_features(CDotFeatures* feat)
00139         {
00140             SG_UNREF(features);
00141             SG_REF(feat);
00142             features=feat;
00143         }
00144 
00149         virtual CLabels* apply();
00150 
00156         virtual CLabels* apply(CFeatures* data);
00157 
00159         virtual float64_t apply(int32_t vec_idx)
00160         {
00161             return features->dense_dot(vec_idx, w, w_dim) + bias;
00162         }
00163 
00168         virtual CDotFeatures* get_features() { SG_REF(features); return features; }
00169 
00175         virtual const char* get_name() const { return "LinearMachine"; }
00176 
00177     protected:
00182         virtual void store_model_features() {}
00183 
00184     protected:
00186         int32_t w_dim;
00188         float64_t* w;
00190         float64_t bias;
00192         CDotFeatures* features;
00193 };
00194 }
00195 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation