00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006 Christian Gehl 00008 * Written (W) 1999-2009 Soeren Sonnenburg 00009 * Written (W) 2011 Sergey Lisitsyn 00010 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00011 */ 00012 00013 #ifndef _KNN_H__ 00014 #define _KNN_H__ 00015 00016 #include <stdio.h> 00017 #include <shogun/lib/common.h> 00018 #include <shogun/io/SGIO.h> 00019 #include <shogun/features/Features.h> 00020 #include <shogun/distance/Distance.h> 00021 #include <shogun/machine/DistanceMachine.h> 00022 00023 namespace shogun 00024 { 00025 class CDistanceMachine; 00026 00053 class CKNN : public CDistanceMachine 00054 { 00055 public: 00057 CKNN(); 00058 00065 CKNN(int32_t k, CDistance* d, CLabels* trainlab); 00066 virtual ~CKNN(); 00067 00072 virtual inline EClassifierType get_classifier_type() { return CT_KNN; } 00073 //inline EDistanceType get_distance_type() { return DT_KNN;} 00074 00079 virtual CLabels* apply(); 00080 00086 virtual CLabels* apply(CFeatures* data); 00087 00089 virtual float64_t apply(int32_t vec_idx) 00090 { 00091 SG_ERROR( "for performance reasons use apply() instead of apply(int32_t vec_idx)\n"); 00092 return 0; 00093 } 00094 00098 SGMatrix<int32_t> classify_for_multiple_k(); 00099 00105 virtual bool load(FILE* srcfile); 00106 00112 virtual bool save(FILE* dstfile); 00113 00118 inline void set_k(int32_t k) 00119 { 00120 ASSERT(k>0); 00121 m_k=k; 00122 } 00123 00128 inline int32_t get_k() 00129 { 00130 return m_k; 00131 } 00132 00136 inline void set_q(float64_t q) 00137 { 00138 ASSERT(q<=1.0 && q>0.0); 00139 m_q = q; 00140 } 00141 00145 inline float64_t get_q() { return m_q; } 00146 00148 inline virtual const char* get_name() const { return "KNN"; } 00149 00150 protected: 00155 virtual void store_model_features(); 00156 00160 virtual CLabels* classify_NN(); 00161 00165 void init_distance(CFeatures* data); 00166 00175 virtual bool train_machine(CFeatures* data=NULL); 00176 00177 private: 00178 void init(); 00179 00180 protected: 00182 int32_t m_k; 00183 00185 float64_t m_q; 00186 00188 int32_t num_classes; 00189 00191 int32_t min_label; 00192 00194 SGVector<int32_t> train_labels; 00195 }; 00196 } 00197 #endif