00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006 Christian Gehl 00008 * Written (W) 1999-2009 Soeren Sonnenburg 00009 * Written (W) 2011 Sergey Lisitsyn 00010 * Written (W) 2012 Fernando José Iglesias García, cover tree support 00011 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00012 */ 00013 00014 #ifndef _KNN_H__ 00015 #define _KNN_H__ 00016 00017 #include <stdio.h> 00018 #include <shogun/lib/common.h> 00019 #include <shogun/io/SGIO.h> 00020 #include <shogun/features/Features.h> 00021 #include <shogun/distance/Distance.h> 00022 #include <shogun/machine/DistanceMachine.h> 00023 00024 namespace shogun 00025 { 00026 00027 class CDistanceMachine; 00028 00055 class CKNN : public CDistanceMachine 00056 { 00057 public: 00058 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00059 00060 00061 CKNN(); 00062 00069 CKNN(int32_t k, CDistance* d, CLabels* trainlab); 00070 virtual ~CKNN(); 00071 00076 virtual EMachineType get_classifier_type() { return CT_KNN; } 00077 //inline EDistanceType get_distance_type() { return DT_KNN;} 00078 00084 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00085 00087 virtual float64_t apply_one(int32_t vec_idx) 00088 { 00089 SG_ERROR( "for performance reasons use apply() instead of apply(int32_t vec_idx)\n"); 00090 return 0; 00091 } 00092 00096 SGMatrix<int32_t> classify_for_multiple_k(); 00097 00103 virtual bool load(FILE* srcfile); 00104 00110 virtual bool save(FILE* dstfile); 00111 00116 inline void set_k(int32_t k) 00117 { 00118 ASSERT(k>0); 00119 m_k=k; 00120 } 00121 00126 inline int32_t get_k() 00127 { 00128 return m_k; 00129 } 00130 00134 inline void set_q(float64_t q) 00135 { 00136 ASSERT(q<=1.0 && q>0.0); 00137 m_q = q; 00138 } 00139 00143 inline float64_t get_q() { return m_q; } 00144 00148 inline void set_use_covertree(bool use_covertree) 00149 { 00150 m_use_covertree = use_covertree; 00151 } 00152 00156 inline bool get_use_covertree() const { return m_use_covertree; } 00157 00159 virtual const char* get_name() const { return "KNN"; } 00160 00161 protected: 00166 virtual void store_model_features(); 00167 00171 virtual CMulticlassLabels* classify_NN(); 00172 00176 void init_distance(CFeatures* data); 00177 00186 virtual bool train_machine(CFeatures* data=NULL); 00187 00188 private: 00189 void init(); 00190 00203 int32_t choose_class(float64_t* classes, int32_t* train_lab); 00204 00205 protected: 00207 int32_t m_k; 00208 00210 float64_t m_q; 00211 00213 bool m_use_covertree; 00214 00216 int32_t m_num_classes; 00217 00219 int32_t m_min_label; 00220 00222 SGVector<int32_t> m_train_labels; 00223 }; 00224 00225 } 00226 #endif