SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
KNN.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2006 Christian Gehl
8  * Written (W) 1999-2009 Soeren Sonnenburg
9  * Written (W) 2011 Sergey Lisitsyn
10  * Written (W) 2012 Fernando José Iglesias García, cover tree support
11  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
12  */
13 
14 #ifndef _KNN_H__
15 #define _KNN_H__
16 
17 #include <shogun/lib/config.h>
18 
19 #include <shogun/lib/common.h>
20 #include <shogun/io/SGIO.h>
24 
25 namespace shogun
26 {
27 
28 class CDistanceMachine;
29 
 /**
  * Class KNN, an implementation of the standard k-nearest neighbor classifier,
  * derived from the generic CDistanceMachine interface.
  *
  * NOTE(review): this listing looks incomplete. The accessors below read and
  * write m_q and m_use_covertree, but neither member is declared anywhere in
  * the class body shown here. Confirm against the real header before relying
  * on this text.
  */
59 class CKNN : public CDistanceMachine
60 {
61  public:
63 
64 
 /** Default constructor. */
65  CKNN();
66 
 /**
  * Constructor.
  *
  * @param k presumably the k parameter in KNN -- TODO confirm in KNN.cpp
  * @param d distance object
  * @param trainlab labels; the name suggests training labels -- TODO confirm
  */
73  CKNN(int32_t k, CDistance* d, CLabels* trainlab);
 /** Virtual destructor. */
74  virtual ~CKNN();
75 
 /**
  * Get the classifier type.
  *
  * @return CT_KNN
  */
80  virtual EMachineType get_classifier_type() { return CT_KNN; }
81 
92 
 /**
  * Apply the classifier and produce multiclass labels.
  *
  * @param data features; defaults to NULL (what NULL means is decided in
  *        KNN.cpp and is not visible here)
  * @return multiclass labels
  */
98  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
99 
 /**
  * Get output for example "vec_idx".
  *
  * Not supported for this classifier: the body only reports an error through
  * SG_ERROR telling the caller to use apply() instead. The trailing
  * `return 0` is only reached if SG_ERROR hands control back.
  *
  * @param vec_idx index of the example (unused)
  * @return 0
  */
101  virtual float64_t apply_one(int32_t vec_idx)
102  {
103  SG_ERROR("for performance reasons use apply() instead of apply(int32_t vec_idx)\n")
104  return 0;
105  }
106 
111 
 /**
  * Load from a file.
  *
  * @param srcfile file handle to read from
  * @return bool; presumably true on success -- TODO confirm in KNN.cpp
  */
117  virtual bool load(FILE* srcfile);
118 
 /**
  * Save to a file.
  *
  * @param dstfile file handle to write to
  * @return bool; presumably true on success -- TODO confirm in KNN.cpp
  */
124  virtual bool save(FILE* dstfile);
125 
 /**
  * Set the k parameter in KNN.
  *
  * @param k new value; asserted to be strictly positive before it is stored
  */
130  inline void set_k(int32_t k)
131  {
132  ASSERT(k>0)
133  m_k=k;
134  }
135 
 /**
  * Get the k parameter in KNN.
  *
  * @return current value of m_k
  */
140  inline int32_t get_k()
141  {
142  return m_k;
143  }
144 
 /**
  * Set parameter q of rank weighting.
  *
  * @param q new value; asserted to satisfy 0 < q <= 1 before it is stored
  */
148  inline void set_q(float64_t q)
149  {
150  ASSERT(q<=1.0 && q>0.0)
151  m_q = q;
152  }
153 
 /**
  * Get parameter q of rank weighting.
  *
  * @return current value of m_q
  */
157  inline float64_t get_q() { return m_q; }
158 
 /**
  * Enable or disable cover tree support.
  *
  * @param use_covertree value stored in m_use_covertree
  */
162  inline void set_use_covertree(bool use_covertree)
163  {
164  m_use_covertree = use_covertree;
165  }
166 
 /**
  * Get whether cover tree support is enabled.
  *
  * @return current value of m_use_covertree
  */
170  inline bool get_use_covertree() const { return m_use_covertree; }
171 
 /**
  * Get the object name.
  *
  * @return the string "KNN"
  */
173  virtual const char* get_name() const { return "KNN"; }
174 
175  protected:
 /** Store model features; implemented in KNN.cpp, behaviour not visible here. */
180  virtual void store_model_features();
181 
 /**
  * Classification helper returning multiclass labels.
  * NOTE(review): the name suggests the single-nearest-neighbor case -- confirm.
  */
185  virtual CMulticlassLabels* classify_NN();
186 
 /**
  * Initialise the distance.
  *
  * @param data features passed to the initialisation; exact use is in KNN.cpp
  */
190  void init_distance(CFeatures* data);
191 
 /**
  * Train the machine.
  *
  * @param data features; defaults to NULL (what NULL means is decided in
  *        KNN.cpp and is not visible here)
  * @return bool; presumably true on success -- TODO confirm
  */
200  virtual bool train_machine(CFeatures* data=NULL);
201 
202  private:
 /** Private initialisation helper; implemented in KNN.cpp. */
203  void init();
204 
 /**
  * Choose a class.
  *
  * @param classes floating-point work buffer; layout not visible here
  * @param train_lab integer label buffer; layout not visible here
  * @return the chosen class as an int32_t
  */
217  int32_t choose_class(float64_t* classes, int32_t* train_lab);
218 
 /**
  * Variant of class selection that writes its results into a buffer.
  *
  * @param output buffer receiving the results
  * @param classes integer work buffer; layout not visible here
  * @param train_lab integer label buffer; layout not visible here
  * @param step presumably a stride into output -- TODO confirm in KNN.cpp
  */
231  void choose_class_for_multiple_k(int32_t* output, int32_t* classes, int32_t* train_lab, int32_t step);
232 
233  protected:
 /** the k parameter in KNN */
235  int32_t m_k;
236 
239 
242 
 /** number of classes (i.e. number of values labels can take) */
244  int32_t m_num_classes;
245 
 /** smallest label */
247  int32_t m_min_label;
248 
251 };
252 
253 }
254 #endif
EMachineType
Definition: Machine.h:33
virtual void store_model_features()
Definition: KNN.cpp:450
virtual bool save(FILE *dstfile)
Definition: KNN.cpp:443
virtual EMachineType get_classifier_type()
Definition: KNN.h:80
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:87
void init_distance(CFeatures *data)
Definition: KNN.cpp:422
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
float64_t get_q()
Definition: KNN.h:157
SGMatrix< int32_t > classify_for_multiple_k()
Definition: KNN.cpp:333
#define SG_ERROR(...)
Definition: SGIO.h:129
int32_t get_k()
Definition: KNN.h:140
int32_t m_min_label
smallest label, i.e. -1
Definition: KNN.h:247
virtual bool train_machine(CFeatures *data=NULL)
Definition: KNN.cpp:72
void set_q(float64_t q)
Definition: KNN.h:148
SGMatrix< index_t > nearest_neighbors()
Definition: KNN.cpp:109
A generic DistanceMachine interface.
virtual bool load(FILE *srcfile)
Definition: KNN.cpp:436
int32_t m_num_classes
number of classes (i.e. number of values labels can take)
Definition: KNN.h:244
Multiclass Labels for multi-class classification.
int32_t m_k
the k parameter in KNN
Definition: KNN.h:235
#define ASSERT(x)
Definition: SGIO.h:201
void set_use_covertree(bool use_covertree)
Definition: KNN.h:162
#define MACHINE_PROBLEM_TYPE(PT)
Definition: Machine.h:120
double float64_t
Definition: common.h:50
Class KNN, an implementation of the standard k-nearest neighbor classifier.
Definition: KNN.h:59
float64_t m_q
parameter q of rank weighting
Definition: KNN.h:238
SGVector< int32_t > m_train_labels
Definition: KNN.h:250
bool get_use_covertree() const
Definition: KNN.h:170
void set_k(int32_t k)
Definition: KNN.h:130
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
virtual const char * get_name() const
Definition: KNN.h:173
virtual ~CKNN()
Definition: KNN.cpp:68
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual CMulticlassLabels * classify_NN()
Definition: KNN.cpp:288
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: KNN.cpp:153
virtual float64_t apply_one(int32_t vec_idx)
get output for example "vec_idx"
Definition: KNN.h:101
bool m_use_covertree
parameter to enable cover tree support
Definition: KNN.h:241

SHOGUN Machine Learning Toolbox - Documentation