00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Copyright (C) 2012 Fernando José Iglesias García 00009 */ 00010 00011 #ifndef _QDA_H__ 00012 #define _QDA_H__ 00013 00014 #include <shogun/lib/config.h> 00015 00016 #ifdef HAVE_LAPACK 00017 00018 #include <shogun/features/DotFeatures.h> 00019 #include <shogun/features/DenseFeatures.h> 00020 #include <shogun/machine/NativeMulticlassMachine.h> 00021 #include <shogun/lib/SGNDArray.h> 00022 00023 namespace shogun 00024 { 00025 00026 //#define DEBUG_QDA 00027 00036 class CQDA : public CNativeMulticlassMachine 00037 { 00038 public: 00039 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00040 00041 00046 CQDA(float64_t tolerance = 1e-4, bool store_covs = false); 00047 00055 CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance = 1e-4, bool store_covs = false); 00056 00057 virtual ~CQDA(); 00058 00064 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00065 00070 inline void set_store_covs(bool store_covs) { m_store_covs = store_covs; } 00071 00076 inline bool get_store_covs() { return m_store_covs; } 00077 00082 inline void set_tolerance(float64_t tolerance) { m_tolerance = tolerance; } 00083 00088 inline bool get_tolerance() { return m_tolerance; } 00089 00094 virtual EMachineType get_classifier_type() { return CT_QDA; } 00095 00100 virtual void set_features(CDotFeatures* feat) 00101 { 00102 if (feat->get_feature_class() != C_DENSE || 00103 feat->get_feature_type() != F_DREAL) 00104 SG_ERROR("QDA requires SIMPLE REAL valued features\n"); 00105 00106 SG_UNREF(m_features); 00107 SG_REF(feat); 00108 m_features = feat; 00109 } 00110 00115 virtual CDotFeatures* get_features() { SG_REF(m_features); return m_features; } 00116 00121 virtual const char* get_name() const { return "QDA"; } 00122 00129 inline SGVector< float64_t > get_mean(int32_t c) const 00130 { 00131 return SGVector< float64_t >(m_means.get_column_vector(c), m_dim, false); 00132 } 00133 00140 inline SGMatrix< float64_t > get_cov(int32_t c) const 00141 { 00142 return SGMatrix< float64_t >(m_covs.get_matrix(c), m_dim, m_dim, false); 00143 } 00144 00145 protected: 00152 virtual bool train_machine(CFeatures* data = NULL); 00153 00154 private: 00155 void init(); 00156 00157 void cleanup(); 00158 00159 private: 00161 CDotFeatures* m_features; 00162 00164 float64_t m_tolerance; 00165 00167 bool m_store_covs; 00168 00170 int32_t m_num_classes; 00171 00173 int32_t m_dim; 00174 00178 SGNDArray< float64_t > m_covs; 00179 00181 SGMatrix< float64_t > m_means; 00182 00184 SGNDArray< float64_t > m_M; 00185 00187 SGVector< float32_t > m_slog; 00188 00189 }; /* class QDA */ 00190 } /* namespace shogun */ 00191 00192 #endif /* HAVE_LAPACK */ 00193 #endif /* _QDA_H__ */