00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _LINEARCLASSIFIER_H__ 00012 #define _LINEARCLASSIFIER_H__ 00013 00014 #include "lib/common.h" 00015 #include "features/Labels.h" 00016 #include "features/DotFeatures.h" 00017 #include "classifier/Classifier.h" 00018 00019 #include <stdio.h> 00020 00021 namespace shogun 00022 { 00023 class CDotFeatures; 00024 class CLabels; 00025 00060 class CLinearClassifier : public CClassifier 00061 { 00062 public: 00064 CLinearClassifier(); 00065 virtual ~CLinearClassifier(); 00066 00068 virtual inline float64_t classify_example(int32_t vec_idx) 00069 { 00070 return features->dense_dot(vec_idx, w, w_dim) + bias; 00071 } 00072 00078 inline void get_w(float64_t*& dst_w, int32_t& dst_dims) 00079 { 00080 ASSERT(w && w_dim>0); 00081 dst_w=w; 00082 dst_dims=w_dim; 00083 } 00084 00090 inline void get_w(float64_t** dst_w, int32_t* dst_dims) 00091 { 00092 ASSERT(dst_w && dst_dims); 00093 ASSERT(w && w_dim>0); 00094 *dst_dims=w_dim; 00095 *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims)); 00096 ASSERT(*dst_w); 00097 memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims)); 00098 } 00099 00105 inline void set_w(float64_t* src_w, int32_t src_w_dim) 00106 { 00107 delete[] w; 00108 w=new float64_t[src_w_dim]; 00109 memcpy(w, src_w, size_t(src_w_dim)*sizeof(float64_t)); 00110 w_dim=src_w_dim; 00111 } 00112 00117 inline void set_bias(float64_t b) 00118 { 00119 bias=b; 00120 } 00121 00126 inline float64_t get_bias() 00127 { 00128 return bias; 00129 } 00130 00136 virtual bool load(FILE* srcfile); 00137 00143 virtual bool save(FILE* dstfile); 00144 00149 virtual inline void set_features(CDotFeatures* feat) 00150 { 00151 SG_UNREF(features); 00152 SG_REF(feat); 00153 features=feat; 00154 } 00155 00160 virtual CLabels* classify(); 00161 00167 virtual CLabels* classify(CFeatures* data); 00168 00173 virtual CDotFeatures* get_features() { SG_REF(features); return features; } 00174 00180 virtual const char* get_name(void) const { 00181 return "LinearClassifier"; } 00182 00183 protected: 00185 int32_t w_dim; 00187 float64_t* w; 00189 float64_t bias; 00191 CDotFeatures* features; 00192 }; 00193 } 00194 #endif