QDA.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #ifndef _QDA_H__
00012 #define _QDA_H__
00013 
00014 #include <shogun/lib/config.h>
00015 
00016 #ifdef HAVE_LAPACK
00017 
00018 #include <shogun/features/DotFeatures.h>
00019 #include <shogun/features/DenseFeatures.h>
00020 #include <shogun/machine/NativeMulticlassMachine.h>
00021 #include <shogun/lib/SGNDArray.h>
00022 
00023 namespace shogun
00024 {
00025 
00026 //#define DEBUG_QDA
00027 
00036 class CQDA : public CNativeMulticlassMachine
00037 {
00038     public:
00039         MACHINE_PROBLEM_TYPE(PT_MULTICLASS)
00040 
00041         
00046         CQDA(float64_t tolerance = 1e-4, bool store_covs = false);
00047 
00055         CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance = 1e-4, bool store_covs = false);
00056 
00057         virtual ~CQDA();
00058 
00064         virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00065 
00070         inline void set_store_covs(bool store_covs) { m_store_covs = store_covs; }
00071 
00076         inline bool get_store_covs() { return m_store_covs; }
00077 
00082         inline void set_tolerance(float64_t tolerance) { m_tolerance = tolerance; }
00083 
00088         inline bool get_tolerance() { return m_tolerance; }
00089 
00094         virtual EMachineType get_classifier_type() { return CT_QDA; }
00095 
00100         virtual void set_features(CDotFeatures* feat)
00101         {
00102             if (feat->get_feature_class() != C_DENSE ||
00103                 feat->get_feature_type() != F_DREAL)
00104                 SG_ERROR("QDA requires SIMPLE REAL valued features\n");
00105 
00106             SG_UNREF(m_features);
00107             SG_REF(feat);
00108             m_features = feat;
00109         }
00110 
00115         virtual CDotFeatures* get_features() { SG_REF(m_features); return m_features; }
00116 
00121         virtual const char* get_name() const { return "QDA"; }
00122 
00129         inline SGVector< float64_t > get_mean(int32_t c) const
00130         {
00131             return SGVector< float64_t >(m_means.get_column_vector(c), m_dim, false);
00132         }
00133 
00140         inline SGMatrix< float64_t > get_cov(int32_t c) const
00141         {
00142             return SGMatrix< float64_t >(m_covs.get_matrix(c), m_dim, m_dim, false);
00143         }
00144 
00145     protected:
00152         virtual bool train_machine(CFeatures* data = NULL);
00153 
00154     private:
00155         void init();
00156 
00157         void cleanup();
00158 
00159     private:
00161         CDotFeatures* m_features;
00162 
00164         float64_t m_tolerance;
00165 
00167         bool m_store_covs;
00168 
00170         int32_t m_num_classes;
00171 
00173         int32_t m_dim;
00174 
00178         SGNDArray< float64_t > m_covs;
00179 
00181         SGMatrix< float64_t > m_means;
00182 
00184         SGNDArray< float64_t > m_M;
00185 
00187         SGVector< float32_t > m_slog;
00188 
00189 }; /* class QDA */
00190 }  /* namespace shogun */
00191 
00192 #endif /* HAVE_LAPACK */
00193 #endif /* _QDA_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation