00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _LINEARCLASSIFIER_H__ 00012 #define _LINEARCLASSIFIER_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/Labels.h> 00016 #include <shogun/features/DotFeatures.h> 00017 #include <shogun/machine/Machine.h> 00018 00019 #include <stdio.h> 00020 00021 namespace shogun 00022 { 00023 class CDotFeatures; 00024 class CMachine; 00025 class CLabels; 00026 00061 class CLinearMachine : public CMachine 00062 { 00063 public: 00065 CLinearMachine(); 00066 virtual ~CLinearMachine(); 00067 00073 inline void get_w(float64_t*& dst_w, int32_t& dst_dims) 00074 { 00075 ASSERT(w && w_dim>0); 00076 dst_w=w; 00077 dst_dims=w_dim; 00078 } 00079 00084 inline SGVector<float64_t> get_w() 00085 { 00086 return SGVector<float64_t>(w, w_dim, false); 00087 } 00088 00093 inline void set_w(SGVector<float64_t> src_w) 00094 { 00095 SG_FREE(w); 00096 w=src_w.vector; 00097 w_dim=src_w.vlen; 00098 } 00099 00104 inline void set_bias(float64_t b) 00105 { 00106 bias=b; 00107 } 00108 00113 inline float64_t get_bias() 00114 { 00115 return bias; 00116 } 00117 00123 virtual bool load(FILE* srcfile); 00124 00130 virtual bool save(FILE* dstfile); 00131 00136 virtual inline void set_features(CDotFeatures* feat) 00137 { 00138 SG_UNREF(features); 00139 SG_REF(feat); 00140 features=feat; 00141 } 00142 00147 virtual CLabels* apply(); 00148 00154 virtual CLabels* apply(CFeatures* data); 00155 00157 virtual float64_t apply(int32_t vec_idx) 00158 { 00159 return features->dense_dot(vec_idx, w, w_dim) + bias; 00160 } 00161 00166 virtual CDotFeatures* get_features() { SG_REF(features); return features; } 00167 00173 virtual const char* get_name() const { return "LinearMachine"; } 00174 00175 protected: 00180 virtual void store_model_features() {} 00181 00182 protected: 00184 int32_t w_dim; 00186 float64_t* w; 00188 float64_t bias; 00190 CDotFeatures* features; 00191 }; 00192 } 00193 #endif