Gaussian.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Alesis Novik
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #ifndef _GAUSSIAN_H__
00012 #define _GAUSSIAN_H__
00013 
00014 #include <shogun/lib/config.h>
00015 
00016 #ifdef HAVE_LAPACK
00017 
00018 #include <shogun/distributions/Distribution.h>
00019 #include <shogun/features/DotFeatures.h>
00020 #include <shogun/lib/common.h>
00021 #include <shogun/mathematics/lapack.h>
00022 #include <shogun/mathematics/Math.h>
00023 
00024 namespace shogun
00025 {
class CDotFeatures;

/** Covariance representation used by CGaussian.
 *
 * Enumerator values are spelled out explicitly; they match the implicit
 * numbering (0, 1, 2) of the original declaration.
 */
enum ECovType
{
    /// full covariance matrix
    FULL = 0,
    /// diagonal covariance matrix
    DIAG = 1,
    /// spherical covariance (presumably one variance shared by all dimensions)
    SPHERICAL = 2
};
00038 
00046 class CGaussian : public CDistribution
00047 {
00048     public:
00050         CGaussian();
00057         CGaussian(SGVector<float64_t> mean, SGMatrix<float64_t> cov, ECovType cov_type=FULL);
00058         virtual ~CGaussian();
00059 
00061         void init();
00062 
00069         virtual bool train(CFeatures* data=NULL);
00070 
00075         virtual int32_t get_num_model_parameters();
00076 
00082         virtual float64_t get_log_model_parameter(int32_t num_param);
00083 
00090         virtual float64_t get_log_derivative(
00091             int32_t num_param, int32_t num_example);
00092 
00100         virtual float64_t get_log_likelihood_example(int32_t num_example);
00101 
00107         virtual inline float64_t compute_PDF(SGVector<float64_t> point)
00108         {
00109             return CMath::exp(compute_log_PDF(point));
00110         }
00111 
00117         virtual float64_t compute_log_PDF(SGVector<float64_t> point);
00118 
00123         virtual inline SGVector<float64_t> get_mean()
00124         {
00125             return m_mean;
00126         }
00127 
00132         virtual inline void set_mean(SGVector<float64_t> mean)
00133         {
00134             m_mean.destroy_vector();
00135             if (mean.vlen==1)
00136                 m_cov_type=SPHERICAL;
00137 
00138             m_mean=mean;
00139         }
00140 
00145         virtual SGMatrix<float64_t> get_cov();
00146 
00153         virtual inline void set_cov(SGMatrix<float64_t> cov)
00154         {
00155             ASSERT(cov.num_rows==cov.num_cols);
00156             ASSERT(cov.num_rows==m_mean.vlen);
00157             decompose_cov(cov);
00158             init();
00159             if (cov.do_free)
00160                 cov.free_matrix();
00161         }
00162 
00167         inline ECovType get_cov_type()
00168         {
00169             return m_cov_type;
00170         }
00171 
00178         inline void set_cov_type(ECovType cov_type)
00179         {
00180             m_cov_type = cov_type;
00181         }
00182 
00187         inline SGVector<float64_t> get_d()
00188         {
00189             return m_d;
00190         }
00191 
00196         inline void set_d(SGVector<float64_t> d)
00197         {
00198             m_d.destroy_vector();
00199             m_d = d;
00200             init();
00201         }
00202 
00207         inline SGMatrix<float64_t> get_u()
00208         {
00209             return m_u;
00210         }
00211 
00216         inline void set_u(SGMatrix<float64_t> u)
00217         {
00218             m_u.destroy_matrix();
00219             m_u = u;
00220         }
00221 
00226         SGVector<float64_t> sample();
00227 
00229         inline virtual const char* get_name() const { return "Gaussian"; }
00230 
00231     private:
00233         void register_params();
00234 
00239         void decompose_cov(SGMatrix<float64_t> cov);
00240 
00241     protected:
00243         float64_t m_constant;
00245         SGVector<float64_t> m_d;
00247         SGMatrix<float64_t> m_u;
00249         SGVector<float64_t> m_mean;
00251         ECovType m_cov_type;
00252 };
00253 }
00254 #endif //HAVE_LAPACK
00255 #endif //_GAUSSIAN_H__
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation