GaussianNaiveBayes.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #ifndef GAUSSIANNAIVEBAYES_H_
00012 #define GAUSSIANNAIVEBAYES_H_
00013 
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/features/DotFeatures.h>
00016 
00017 namespace shogun {
00018 
00019 class CLabels;
00020 class CDotFeatures;
00021 class CFeatures;
00022 
00034 class CGaussianNaiveBayes : public CMachine
00035 {
00036 
00037 public:
00041     CGaussianNaiveBayes();
00042 
00047     CGaussianNaiveBayes(CFeatures* train_examples, CLabels* train_labels);
00048 
00052     virtual ~CGaussianNaiveBayes();
00053 
00057     virtual inline void set_features(CDotFeatures* features)
00058     {
00059          SG_UNREF(m_features);
00060          SG_REF(features);
00061          m_features = features;
00062     }
00063 
00067     virtual inline CDotFeatures* get_features()
00068     {
00069         SG_REF(m_features);
00070         return m_features;
00071     }
00072 
00077     virtual bool train(CFeatures* data = NULL);
00078 
00082     virtual CLabels* apply();
00083 
00088     virtual CLabels* apply(CFeatures* data);
00089 
00094     virtual float64_t apply(int32_t idx);
00095 
00099     virtual inline const char* get_name() const { return "GaussianNaiveBayes"; };
00100 
00104     virtual inline EClassifierType get_classifier_type() { return CT_GAUSSIANNAIVEBAYES; };
00105 
00106 protected:
00107 
00109     CDotFeatures* m_features;
00110 
00112     int32_t m_min_label;
00113 
00115     int32_t m_num_classes;
00116 
00118     int32_t m_dim;
00119 
00121     SGVector<float64_t> m_means;
00122 
00124     SGVector<float64_t> m_variances;
00125 
00127     SGVector<float64_t> m_label_prob;
00128 
00135     float64_t inline normal_exp(float64_t x, int32_t l_idx, int32_t f_idx)
00136     {
00137         return CMath::exp(-CMath::sq(x-m_means.vector[m_dim*l_idx+f_idx])/(2*m_variances.vector[m_dim*l_idx+f_idx]));
00138     }
00139 
00141     SGVector<float64_t> m_rates;
00142 };
00143 
00144 }
00145 
00146 #endif /* GAUSSIANNAIVEBAYES_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation