GaussianNaiveBayes.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #ifndef GAUSSIANNAIVEBAYES_H_
00012 #define GAUSSIANNAIVEBAYES_H_
00013 
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/features/DotFeatures.h>
00017 
00018 namespace shogun {
00019 
// Forward declarations of the types that the CGaussianNaiveBayes interface
// below refers to through pointers (CLabels*, CDotFeatures*, CFeatures*).
00020 class CLabels;
00021 class CDotFeatures;
00022 class CFeatures;
00023 
00035 class CGaussianNaiveBayes : public CMachine
00036 {
00037 
00038 public:
00042     CGaussianNaiveBayes();
00043 
00048     CGaussianNaiveBayes(CFeatures* train_examples, CLabels* train_labels);
00049 
00053     virtual ~CGaussianNaiveBayes();
00054 
00058     virtual inline void set_features(CDotFeatures* features)
00059     {
00060          SG_UNREF(m_features);
00061          SG_REF(features);
00062          m_features = features;
00063     }
00064 
00068     virtual inline CDotFeatures* get_features()
00069     {
00070         SG_REF(m_features);
00071         return m_features;
00072     }
00073 
00078     virtual bool train(CFeatures* data = NULL);
00079 
00083     virtual CLabels* apply();
00084 
00089     virtual CLabels* apply(CFeatures* data);
00090 
00095     virtual float64_t apply(int32_t idx);
00096 
00100     virtual inline const char* get_name() const { return "GaussianNaiveBayes"; };
00101 
00105     virtual inline EClassifierType get_classifier_type() { return CT_GAUSSIANNAIVEBAYES; };
00106 
00107 protected:
00108 
00110     CDotFeatures* m_features;
00111 
00113     int32_t m_min_label;
00114 
00116     int32_t m_num_classes;
00117 
00119     int32_t m_dim;
00120 
00122     SGVector<float64_t> m_means;
00123 
00125     SGVector<float64_t> m_variances;
00126 
00128     SGVector<float64_t> m_label_prob;
00129 
00136     float64_t inline normal_exp(float64_t x, int32_t l_idx, int32_t f_idx)
00137     {
00138         return CMath::exp(-CMath::sq(x-m_means.vector[m_dim*l_idx+f_idx])/(2*m_variances.vector[m_dim*l_idx+f_idx]));
00139     }
00140 
00142     SGVector<float64_t> m_rates;
00143 };
00144 
00145 }
00146 
00147 #endif /* GAUSSIANNAIVEBAYES_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation