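/*
 * Declaration of shogun::CGMM, a Gaussian mixture model distribution.
 *
 * NOTE(review): this copy was extracted from a Doxygen source listing and
 * the original contents of lines 1-9 (presumably the licence/author block)
 * were lost; restore them from upstream.
 */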
00010 #ifndef _GMM_H__
00011 #define _GMM_H__
00012
00013 #include <shogun/lib/config.h>
00014
00015 #ifdef HAVE_LAPACK
00016
00017 #include <shogun/distributions/Distribution.h>
00018 #include <shogun/distributions/Gaussian.h>
00019 #include <shogun/lib/common.h>
00020
00021 #include <vector>
00022
00023 using namespace std;
00024
00025 namespace shogun
00026 {
00040 class CGMM : public CDistribution
00041 {
00042 public:
00044 CGMM();
00050 CGMM(int32_t n, ECovType cov_type=FULL);
00057 CGMM(vector<CGaussian*> components, SGVector<float64_t> coefficients,
00058 bool copy=false);
00059 virtual ~CGMM();
00060
00062 void cleanup();
00063
00070 virtual bool train(CFeatures* data=NULL);
00071
00080 float64_t train_em(float64_t min_cov=1e-9, int32_t max_iter=1000,
00081 float64_t min_change=1e-9);
00082
00093 float64_t train_smem(int32_t max_iter=100, int32_t max_cand=5,
00094 float64_t min_cov=1e-9, int32_t max_em_iter=1000,
00095 float64_t min_change=1e-9);
00096
00102 void max_likelihood(SGMatrix<float64_t> alpha, float64_t min_cov);
00103
00108 virtual int32_t get_num_model_parameters();
00109
00115 virtual float64_t get_log_model_parameter(int32_t num_param);
00116
00123 virtual float64_t get_log_derivative(
00124 int32_t num_param, int32_t num_example);
00125
00133 virtual float64_t get_log_likelihood_example(int32_t num_example);
00134
00142 virtual float64_t get_likelihood_example(int32_t num_example);
00143
00150 virtual SGVector<float64_t> get_nth_mean(int32_t num);
00151
00157 virtual void set_nth_mean(SGVector<float64_t> mean, int32_t num);
00158
00165 virtual SGMatrix<float64_t> get_nth_cov(int32_t num);
00166
00172 virtual void set_nth_cov(SGMatrix<float64_t> cov, int32_t num);
00173
00178 virtual SGVector<float64_t> get_coef();
00179
00184 virtual void set_coef(const SGVector<float64_t> coefficients);
00185
00190 virtual vector<CGaussian*> get_comp();
00191
00196 virtual void set_comp(vector<CGaussian*> components);
00197
00202 SGVector<float64_t> sample();
00203
00209 SGVector<float64_t> cluster(SGVector<float64_t> point);
00210
00212 virtual const char* get_name() const { return "GMM"; }
00213
00214 private:
00221 SGMatrix<float64_t> alpha_init(SGMatrix<float64_t> init_means);
00222
00224 void register_params();
00225
00235 void partial_em(int32_t comp1, int32_t comp2, int32_t comp3,
00236 float64_t min_cov, int32_t max_em_iter, float64_t min_change);
00237
00238 protected:
00240 vector<CGaussian*> m_components;
00242 SGVector<float64_t> m_coefficients;
00243 };
00244 }
00245 #endif //HAVE_LAPACK
00246 #endif //_GMM_H__