00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _LINEARCLASSIFIER_H__ 00012 #define _LINEARCLASSIFIER_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/Labels.h> 00016 #include <shogun/features/DotFeatures.h> 00017 #include <shogun/machine/Machine.h> 00018 00019 #include <stdio.h> 00020 00021 namespace shogun 00022 { 00023 class CDotFeatures; 00024 class CMachine; 00025 class CLabels; 00026 00061 class CLinearMachine : public CMachine 00062 { 00063 public: 00065 CLinearMachine(); 00066 virtual ~CLinearMachine(); 00067 00073 inline void get_w(float64_t*& dst_w, int32_t& dst_dims) 00074 { 00075 ASSERT(w && w_dim>0); 00076 dst_w=w; 00077 dst_dims=w_dim; 00078 } 00079 00084 inline SGVector<float64_t> get_w() 00085 { 00086 return SGVector<float64_t>(w, w_dim); 00087 } 00088 00094 inline void set_w(float64_t* src_w, int32_t src_w_dim) 00095 { 00096 SG_FREE(w); 00097 w=SG_MALLOC(float64_t, src_w_dim); 00098 memcpy(w, src_w, size_t(src_w_dim)*sizeof(float64_t)); 00099 w_dim=src_w_dim; 00100 } 00101 00106 inline void set_bias(float64_t b) 00107 { 00108 bias=b; 00109 } 00110 00115 inline float64_t get_bias() 00116 { 00117 return bias; 00118 } 00119 00125 virtual bool load(FILE* srcfile); 00126 00132 virtual bool save(FILE* dstfile); 00133 00138 virtual inline void set_features(CDotFeatures* feat) 00139 { 00140 SG_UNREF(features); 00141 SG_REF(feat); 00142 features=feat; 00143 } 00144 00149 virtual CLabels* apply(); 00150 00156 virtual CLabels* apply(CFeatures* data); 00157 00159 virtual float64_t apply(int32_t vec_idx) 00160 { 00161 return features->dense_dot(vec_idx, w, w_dim) + bias; 00162 } 00163 00168 virtual CDotFeatures* get_features() { SG_REF(features); return features; } 00169 00175 virtual const char* get_name() const { return "LinearMachine"; } 00176 00177 protected: 00182 virtual void store_model_features() {} 00183 00184 protected: 00186 int32_t w_dim; 00188 float64_t* w; 00190 float64_t bias; 00192 CDotFeatures* features; 00193 }; 00194 } 00195 #endif