SHOGUN v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
KNN.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2006 Christian Gehl
8  * Written (W) 1999-2009 Soeren Sonnenburg
9  * Written (W) 2011 Sergey Lisitsyn
10  * Written (W) 2012 Fernando José Iglesias García, cover tree support
11  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
12  */
13 
14 #ifndef _KNN_H__
15 #define _KNN_H__
16 
17 #include <stdio.h>
18 #include <shogun/lib/common.h>
19 #include <shogun/io/SGIO.h>
23 
/* Source listing of KNN.h. The leading integer on each line is the line number
 * in the original header. Gaps in that numbering (e.g. 20-22, 29-54, 209-222)
 * are lines this listing does not show — presumably the documentation comments,
 * some includes and some member declarations; confirm against the real header. */
24 namespace shogun
25 {
26 
/* NOTE(review): a forward declaration alone is not enough for the base-class
 * use below — `class CKNN : public CDistanceMachine` requires the complete type.
 * The full definition must come from an include, presumably among the dropped
 * lines 20-22 (e.g. DistanceMachine.h); confirm. The same applies to CDistance,
 * CLabels, CFeatures and CMulticlassLabels, which are used but not declared here. */
27 class CDistanceMachine;
28 
/** k-nearest-neighbour classifier (machine type CT_KNN, name "KNN") built on
 * CDistanceMachine. It produces multiclass labels via apply_multiclass(), and
 * the number of neighbours is m_k (set_k/get_k). */
55 class CKNN : public CDistanceMachine
56 {
57  public:
59 
60 
 /** Default constructor. */
61  CKNN();
62 
 /** Constructor.
  * @param k number of neighbours (presumably stored in m_k; body not shown)
  * @param d distance object to use
  * @param trainlab training labels
  */
69  CKNN(int32_t k, CDistance* d, CLabels* trainlab);
 /** Destructor. */
70  virtual ~CKNN();
71 
 /** @return classifier type, always CT_KNN */
76  virtual inline EMachineType get_classifier_type() { return CT_KNN; }
 // Disabled accessor, kept only as a comment; not part of the interface.
77  //inline EDistanceType get_distance_type() { return DT_KNN;}
78 
 /** Classify data and return multiclass labels.
  * @param data features to classify (default NULL; presumably NULL means
  *             "use the features already set" — implementation not shown, confirm)
  * @return multiclass labels
  */
84  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
85 
 /** Single-vector classification is deliberately unsupported: it reports an
  * error via SG_ERROR telling the caller to use apply() instead. The
  * `return 0` only satisfies the float64_t return type.
  * @param vec_idx index of the vector (unused)
  */
87  virtual float64_t apply_one(int32_t vec_idx)
88  {
89  SG_ERROR( "for performance reasons use apply() instead of apply(int32_t vec_idx)\n");
90  return 0;
91  }
92 
97 
 /** Load the model from a file.
  * @param srcfile file to read from
  * @return success flag (implementation not shown)
  */
103  virtual bool load(FILE* srcfile);
104 
 /** Save the model to a file.
  * @param dstfile file to write to
  * @return success flag (implementation not shown)
  */
110  virtual bool save(FILE* dstfile);
111 
 /** Set the number of neighbours.
  * @param k number of neighbours; must be > 0 (enforced by ASSERT)
  */
116  inline void set_k(int32_t k)
117  {
118  ASSERT(k>0);
119  m_k=k;
120  }
121 
 /** @return current number of neighbours m_k.
  * NOTE(review): not const, unlike get_use_covertree(). */
126  inline int32_t get_k()
127  {
128  return m_k;
129  }
130 
 /** Set parameter q.
  * @param q must satisfy 0 < q <= 1 (enforced by ASSERT); presumably a
  *          neighbour-rank weighting factor — semantics not shown, confirm.
  */
134  inline void set_q(float64_t q)
135  {
136  ASSERT(q<=1.0 && q>0.0);
137  m_q = q;
138  }
139 
 /** @return parameter q (m_q).
  * NOTE(review): not const, unlike get_use_covertree(). */
143  inline float64_t get_q() { return m_q; }
144 
 /** Enable or disable cover-tree support (presumably used for the neighbour
  * search — not shown here).
  * @param use_covertree new value of m_use_covertree
  */
148  inline void set_use_covertree(bool use_covertree)
149  {
150  m_use_covertree = use_covertree;
151  }
152 
 /** @return whether cover-tree support is enabled (m_use_covertree) */
156  inline bool get_use_covertree() const { return m_use_covertree; }
157 
 /** @return object name, "KNN" */
159  inline virtual const char* get_name() const { return "KNN"; }
160 
161  protected:
 /** Store the model's features (implementation not shown). */
166  virtual void store_model_features();
167 
 /** Nearest-neighbour classification producing multiclass labels
  * (implementation not shown; presumably the k == 1 path — confirm). */
171  virtual CMulticlassLabels* classify_NN();
172 
 /** Initialise the distance with the given data (implementation not shown).
  * @param data features to initialise with
  */
176  void init_distance(CFeatures* data);
177 
 /** Train the machine.
  * @param data training features (default NULL; handling not shown)
  * @return success flag
  */
186  virtual bool train_machine(CFeatures* data=NULL);
187 
188  private:
 /** Shared initialisation helper (implementation not shown). */
189  void init();
190 
 /** Pick a class from per-class scores (implementation not shown).
  * @param classes per-class values (presumably votes/weights; confirm)
  * @param train_lab training labels
  * @return chosen class
  */
203  int32_t choose_class(float64_t* classes, int32_t* train_lab);
204 
205  protected:
 /// number of neighbours, set via set_k()
207  int32_t m_k;
208 
 /* NOTE(review): m_q and m_use_covertree are read and written above, but
  * their declarations do not appear in this listing (source lines 209-222
  * are partly missing). Confirm they are declared in the real header. */
211 
214 
 /// number of classes (meaning not shown; presumably set during training)
216  int32_t m_num_classes;
217 
 /// minimum label value (presumably set during training; confirm)
219  int32_t m_min_label;
220 
223 };
224 
225 }
226 #endif

SHOGUN Machine Learning Toolbox - Documentation