00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/GaussianNaiveBayes.h>
00012 #include <shogun/features/Features.h>
00013 #include <shogun/labels/Labels.h>
00014 #include <shogun/labels/RegressionLabels.h>
00015 #include <shogun/labels/MulticlassLabels.h>
00016 #include <shogun/mathematics/Math.h>
00017 #include <shogun/lib/Signal.h>
00018
00019 using namespace shogun;
00020
/** Default constructor: creates an untrained model with no features,
 * zero classes and empty parameter matrices/vectors.
 */
CGaussianNaiveBayes::CGaussianNaiveBayes() : CNativeMulticlassMachine(), m_features(NULL),
m_min_label(0), m_num_classes(0), m_dim(0), m_means(), m_variances(),
m_label_prob(), m_rates()
{

};
00027
/** Constructor taking training data.
 * @param train_examples training examples; must have the FP_DOT property
 *        (i.e. be CDotFeatures), otherwise SG_ERROR is raised
 * @param train_labels labels of the training examples; the count must
 *        match the number of example vectors
 */
CGaussianNaiveBayes::CGaussianNaiveBayes(CFeatures* train_examples,
CLabels* train_labels) : CNativeMulticlassMachine(), m_features(NULL),
m_min_label(0), m_num_classes(0), m_dim(0), m_means(),
m_variances(), m_label_prob(), m_rates()
{
	// one label per training example is required
	ASSERT(train_examples->get_num_vectors() == train_labels->get_num_labels());
	set_labels(train_labels);

	// only dot-product-capable features are supported
	if (!train_examples->has_property(FP_DOT))
		SG_ERROR("Specified features are not of type CDotFeatures\n");

	set_features((CDotFeatures*)train_examples);
};
00041
/** Destructor: releases the reference this machine holds on the features. */
CGaussianNaiveBayes::~CGaussianNaiveBayes()
{
	SG_UNREF(m_features);
};
00046
/** Returns the stored feature object.
 * A new reference is taken (SG_REF) before returning, so the caller
 * receives ownership of one reference and must SG_UNREF it when done.
 * @return the current features (may be NULL if none were set)
 */
CFeatures* CGaussianNaiveBayes::get_features()
{
	SG_REF(m_features);
	return m_features;
}
00052
00053 void CGaussianNaiveBayes::set_features(CFeatures* features)
00054 {
00055 if (!features->has_property(FP_DOT))
00056 SG_ERROR("Specified features are not of type CDotFeatures\n");
00057
00058 SG_UNREF(m_features);
00059 SG_REF(features);
00060 m_features = (CDotFeatures*)features;
00061 }
00062
00063 bool CGaussianNaiveBayes::train_machine(CFeatures* data)
00064 {
00065
00066 if (data)
00067 {
00068 if (!data->has_property(FP_DOT))
00069 SG_ERROR("Specified features are not of type CDotFeatures\n");
00070 set_features((CDotFeatures*) data);
00071 }
00072
00073
00074 ASSERT(m_labels);
00075 ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00076 SGVector<int32_t> train_labels = ((CMulticlassLabels*) m_labels)->get_int_labels();
00077 ASSERT(m_features->get_num_vectors()==train_labels.vlen);
00078
00079
00080 int32_t min_label = train_labels.vector[0];
00081 int32_t max_label = train_labels.vector[0];
00082 int i,j;
00083
00084
00085 for (i=1; i<train_labels.vlen; i++)
00086 {
00087 min_label = CMath::min(min_label, train_labels.vector[i]);
00088 max_label = CMath::max(max_label, train_labels.vector[i]);
00089 }
00090
00091
00092 for (i=0; i<train_labels.vlen; i++)
00093 train_labels.vector[i]-= min_label;
00094
00095
00096 m_num_classes = max_label-min_label+1;
00097 m_min_label = min_label;
00098 m_dim = m_features->get_dim_feature_space();
00099
00100
00101 m_means=SGMatrix<float64_t>(m_dim,m_num_classes);
00102 m_variances=SGMatrix<float64_t>(m_dim, m_num_classes);
00103 m_label_prob=SGVector<float64_t>(m_num_classes);
00104
00105
00106 m_rates=SGVector<float64_t>(m_num_classes);
00107
00108
00109 m_means.zero();
00110 m_variances.zero();
00111 m_label_prob.zero();
00112 m_rates.zero();
00113
00114
00115 int32_t max_progress = 2 * train_labels.vlen + 2 * m_num_classes;
00116
00117
00118 int32_t progress = 0;
00119 SG_PROGRESS(progress, 0, max_progress);
00120
00121
00122 for (i=0; i<train_labels.vlen; i++)
00123 {
00124 SGVector<float64_t> fea = m_features->get_computed_dot_feature_vector(i);
00125 for (j=0; j<m_dim; j++)
00126 m_means(j, train_labels.vector[i]) += fea.vector[j];
00127
00128 m_label_prob.vector[train_labels.vector[i]]+=1.0;
00129
00130 progress++;
00131 SG_PROGRESS(progress, 0, max_progress);
00132 }
00133
00134
00135 for (i=0; i<m_num_classes; i++)
00136 {
00137 for (j=0; j<m_dim; j++)
00138 m_means(j, i) /= m_label_prob.vector[i];
00139
00140 progress++;
00141 SG_PROGRESS(progress, 0, max_progress);
00142 }
00143
00144
00145 for (i=0; i<train_labels.vlen; i++)
00146 {
00147 SGVector<float64_t> fea = m_features->get_computed_dot_feature_vector(i);
00148 for (j=0; j<m_dim; j++)
00149 {
00150 m_variances(j, train_labels.vector[i]) +=
00151 CMath::sq(fea[j]-m_means(j, train_labels.vector[i]));
00152 }
00153
00154 progress++;
00155 SG_PROGRESS(progress, 0, max_progress);
00156 }
00157
00158
00159 for (i=0; i<m_num_classes; i++)
00160 {
00161 for (j=0; j<m_dim; j++)
00162 m_variances(j, i) /= m_label_prob.vector[i] > 1 ? m_label_prob.vector[i]-1 : 1;
00163
00164
00165 m_label_prob.vector[i]/= m_num_classes;
00166
00167 progress++;
00168 SG_PROGRESS(progress, 0, max_progress);
00169 }
00170 SG_DONE();
00171
00172 return true;
00173 }
00174
00175 CMulticlassLabels* CGaussianNaiveBayes::apply_multiclass(CFeatures* data)
00176 {
00177 if (data)
00178 set_features(data);
00179
00180 ASSERT(m_features);
00181
00182
00183 int32_t num_vectors = m_features->get_num_vectors();
00184
00185
00186 CMulticlassLabels* result = new CMulticlassLabels(num_vectors);
00187
00188
00189 SG_PROGRESS(0, 0, num_vectors);
00190 for (int i = 0; i < num_vectors; i++)
00191 {
00192 result->set_label(i,apply_one(i));
00193 SG_PROGRESS(i + 1, 0, num_vectors);
00194 }
00195 SG_DONE();
00196 return result;
00197 };
00198
00199 float64_t CGaussianNaiveBayes::apply_one(int32_t idx)
00200 {
00201
00202 SGVector<float64_t> feature_vector = m_features->get_computed_dot_feature_vector(idx);
00203
00204
00205 int i,k;
00206
00207
00208 for (i=0; i<m_num_classes; i++)
00209 {
00210
00211 if (m_label_prob.vector[i]==0.0)
00212 {
00213 m_rates.vector[i] = 0.0;
00214 continue;
00215 }
00216 else
00217 m_rates.vector[i] = CMath::log(m_label_prob.vector[i]);
00218
00219
00220 for (k=0; k<m_dim; k++)
00221 if (m_variances(k,i)!=0.0)
00222 m_rates.vector[i]+= CMath::log(0.39894228/CMath::sqrt(m_variances(k, i))) -
00223 0.5*CMath::sq(feature_vector.vector[k]-m_means(k, i))/(m_variances(k, i));
00224 }
00225
00226
00227 int32_t max_label_idx = 0;
00228
00229 for (i=0; i<m_num_classes; i++)
00230 {
00231 if (m_rates.vector[i]>m_rates.vector[max_label_idx])
00232 max_label_idx = i;
00233 }
00234
00235 return max_label_idx+m_min_label;
00236 };