Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/classifier/NearestCentroid.h>
00011 #include <shogun/labels/MulticlassLabels.h>
00012 #include <shogun/features/Features.h>
00013 #include <shogun/features/FeatureTypes.h>
00014
00015
00016
00017 namespace shogun{
00018
00019 CNearestCentroid::CNearestCentroid() : CDistanceMachine()
00020 {
00021 init();
00022 }
00023
00024 CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab) : CDistanceMachine()
00025 {
00026 init();
00027 ASSERT(d);
00028 ASSERT(trainlab);
00029 set_distance(d);
00030 set_labels(trainlab);
00031 }
00032
00033 CNearestCentroid::~CNearestCentroid()
00034 {
00035 if(m_is_trained)
00036 distance->remove_lhs();
00037 else
00038 delete m_centroids;
00039 }
00040
00041 void CNearestCentroid::init()
00042 {
00043 m_shrinking=0;
00044 m_is_trained=false;
00045 m_centroids = new CDenseFeatures<float64_t>();
00046 }
00047
00048
00049 bool CNearestCentroid::train_machine(CFeatures* data)
00050 {
00051 ASSERT(m_labels);
00052 ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00053 ASSERT(distance);
00054 ASSERT( data->get_feature_class() == C_DENSE)
00055 if (data)
00056 {
00057 if (m_labels->get_num_labels() != data->get_num_vectors())
00058 SG_ERROR("Number of training vectors does not match number of labels\n");
00059 distance->init(data, data);
00060 }
00061 else
00062 {
00063 data = distance->get_lhs();
00064 }
00065 int32_t num_vectors = data->get_num_vectors();
00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067 int32_t num_feats = ((CDenseFeatures<float64_t>*) data)->get_num_features();
00068 SGMatrix<float64_t> centroids(num_feats,num_classes);
00069 centroids.zero();
00070
00071 m_centroids->set_num_features(num_feats);
00072 m_centroids->set_num_vectors(num_classes);
00073
00074 int64_t* num_per_class = new int64_t[num_classes];
00075 for (int32_t i=0 ; i<num_classes ; i++)
00076 {
00077 num_per_class[i]=0;
00078 }
00079
00080 for (int32_t idx=0 ; idx<num_vectors ; idx++)
00081 {
00082 int32_t current_len;
00083 bool current_free;
00084 int32_t current_class = ((CMulticlassLabels*) m_labels)->get_label(idx);
00085 float64_t* target = centroids.matrix + num_feats*current_class;
00086 float64_t* current = ((CDenseFeatures<float64_t>*)data)->get_feature_vector(idx,current_len,current_free);
00087 SGVector<float64_t>::add(target,1.0,target,1.0,current,current_len);
00088 num_per_class[current_class]++;
00089 ((CDenseFeatures<float64_t>*)data)->free_feature_vector(current, current_len, current_free);
00090 }
00091
00092
00093 for (int32_t i=0 ; i<num_classes ; i++)
00094 {
00095 float64_t* target = centroids.matrix + num_feats*i;
00096 int32_t total = num_per_class[i];
00097 float64_t scale = 0;
00098 if(total>1)
00099 scale = 1.0/((float64_t)(total-1));
00100 else
00101 scale = 1.0/(float64_t)total;
00102
00103 SGVector<float64_t>::scale_vector(scale,target,num_feats);
00104 }
00105
00106 m_centroids->free_feature_matrix();
00107 m_centroids->set_feature_matrix(centroids);
00108
00109
00110 m_is_trained=true;
00111 distance->init(m_centroids,distance->get_rhs());
00112
00113 SG_FREE(num_per_class);
00114
00115 return true;
00116 }
00117
00118 }