SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MCLDA.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Kevin Hughes
8  * Copyright (C) 2013 Kevin Hughes
9  *
10  * Thanks to Fernando José Iglesias García (shogun)
11  * and Matthieu Perrot (scikit-learn)
12  */
13 
14 #ifndef _MCLDA_H__
15 #define _MCLDA_H__
16 
17 #include <shogun/lib/config.h>
18 
19 #ifdef HAVE_EIGEN3
20 
24 #include <shogun/lib/SGNDArray.h>
25 
26 namespace shogun
27 {
28 
29 //#define DEBUG_MCLDA
30 
40 {
41  public:
43 
44 
49  CMCLDA(float64_t tolerance = 1e-4, bool store_cov = false);
50 
58  CMCLDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance = 1e-4, bool store_cov = false);
59 
60  virtual ~CMCLDA();
61 
67  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
68 
73  inline void set_tolerance(float64_t tolerance) { m_tolerance = tolerance; } // setter: stores the given tolerance in m_tolerance
74 
79  inline float64_t get_tolerance() { return m_tolerance; } // getter for m_tolerance; returns float64_t (the member's type) so the value is not collapsed to true/false
80 
85  virtual EMachineType get_classifier_type() { return CT_LDA; } // TODO: placeholder, reports CT_LDA for now; give MCLDA a proper entry in the machine types later
86 
91  virtual void set_features(CDotFeatures* feat) // replaces m_features with feat after checking its feature class/type
92  {
93  if (feat->get_feature_class() != C_DENSE || // NOTE(review): feat is dereferenced without a NULL check; confirm callers never pass NULL
94  feat->get_feature_type() != F_DREAL) // only C_DENSE features of type F_DREAL pass the check
95  SG_ERROR("MCLDA requires SIMPLE REAL valued features\n") // no trailing ';' here; presumably the macro expands to a complete statement, verify
96 
97  SG_REF(feat); // SG_REF on the new object comes before SG_UNREF on the old one
98  SG_UNREF(m_features); // presumably this order keeps the object alive when feat == m_features; confirm against SG_REF/SG_UNREF semantics
99  m_features = feat;
100  }
101 
106  virtual CDotFeatures* get_features() { SG_REF(m_features); return m_features; } // calls SG_REF on m_features, then returns it; presumably the caller is expected to SG_UNREF the result, confirm
107 
112  virtual const char* get_name() const { return "MCLDA"; } // returns the string literal "MCLDA" as this object's name
113 
120  inline SGVector< float64_t > get_mean(int32_t c) const // returns column c of m_means as a vector of length m_dim
121  {
122  return SGVector< float64_t >(m_means.get_column_vector(c), m_dim, false); // c is passed through unchecked; the trailing 'false' presumably disables ref counting so the vector only views m_means' storage, TODO confirm
123  }
124 
130  {
131  return m_cov;
132  }
133 
134  protected:
141  virtual bool train_machine(CFeatures* data = NULL);
142 
143  private:
144  void init();
145 
146  void cleanup();
147 
148  private:
150  CDotFeatures* m_features;
151 
153  float64_t m_tolerance;
154 
156  bool m_store_cov;
157 
159  int32_t m_num_classes;
160 
162  int32_t m_dim;
163 
167  SGMatrix< float64_t > m_cov;
168 
170  SGMatrix< float64_t > m_means;
171 
173  SGVector< float64_t > m_xbar;
174 
176  int32_t m_rank;
177 
179  SGMatrix< float64_t > m_scalings;
180 
182  SGMatrix< float64_t > m_coef;
183 
185  SGVector< float64_t > m_intercept;
186 
187 }; /* class MCLDA */
188 } /* namespace shogun */
189 
190 #endif /* HAVE_EIGEN3 */
191 #endif /* _MCLDA_H__ */

SHOGUN Machine Learning Toolbox - Documentation