NearestCentroid.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Philippe Tillet
00008  */
00009 
00010 #include <shogun/classifier/NearestCentroid.h>
00011 #include <shogun/labels/MulticlassLabels.h>
00012 #include <shogun/features/Features.h>
00013 #include <shogun/features/FeatureTypes.h>
00014 
00015 
00016 
00017 namespace shogun{
00018     
00019     CNearestCentroid::CNearestCentroid() : CDistanceMachine()
00020     {
00021         init();
00022     }
00023 
00024     CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab) : CDistanceMachine()
00025     {
00026         init();
00027         ASSERT(d);
00028         ASSERT(trainlab);
00029         set_distance(d);
00030         set_labels(trainlab);
00031     }
00032 
00033     CNearestCentroid::~CNearestCentroid()
00034     {
00035         if(m_is_trained)
00036             distance->remove_lhs();
00037         else
00038             delete m_centroids;
00039     }
00040 
00041     void CNearestCentroid::init()
00042     {
00043         m_shrinking=0;
00044         m_is_trained=false;
00045         m_centroids = new CDenseFeatures<float64_t>();
00046     }
00047 
00048 
00049     bool CNearestCentroid::train_machine(CFeatures* data)
00050     {
00051         ASSERT(m_labels);
00052         ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00053         ASSERT(distance);
00054         ASSERT( data->get_feature_class() == C_DENSE)
00055         if (data)
00056         {
00057             if (m_labels->get_num_labels() != data->get_num_vectors())
00058                 SG_ERROR("Number of training vectors does not match number of labels\n");
00059             distance->init(data, data);
00060         }
00061         else
00062         {
00063             data = distance->get_lhs();
00064         }
00065         int32_t num_vectors = data->get_num_vectors();
00066         int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067         int32_t num_feats = ((CDenseFeatures<float64_t>*) data)->get_num_features();
00068         SGMatrix<float64_t> centroids(num_feats,num_classes);
00069         centroids.zero();
00070 
00071         m_centroids->set_num_features(num_feats);
00072         m_centroids->set_num_vectors(num_classes);
00073         
00074         int64_t* num_per_class = new int64_t[num_classes];
00075         for (int32_t i=0 ; i<num_classes ; i++)
00076         {
00077             num_per_class[i]=0;
00078         }
00079         
00080         for (int32_t idx=0 ; idx<num_vectors ; idx++)
00081         {
00082             int32_t current_len;
00083             bool current_free;
00084             int32_t current_class = ((CMulticlassLabels*) m_labels)->get_label(idx);
00085             float64_t* target = centroids.matrix + num_feats*current_class;
00086             float64_t* current = ((CDenseFeatures<float64_t>*)data)->get_feature_vector(idx,current_len,current_free);
00087             SGVector<float64_t>::add(target,1.0,target,1.0,current,current_len);
00088             num_per_class[current_class]++;
00089             ((CDenseFeatures<float64_t>*)data)->free_feature_vector(current, current_len, current_free);
00090         }
00091 
00092 
00093         for (int32_t i=0 ; i<num_classes ; i++)
00094         {
00095             float64_t* target = centroids.matrix + num_feats*i;
00096             int32_t total = num_per_class[i];
00097             float64_t scale = 0;
00098             if(total>1)
00099                 scale = 1.0/((float64_t)(total-1));
00100             else
00101                 scale = 1.0/(float64_t)total;
00102                 
00103             SGVector<float64_t>::scale_vector(scale,target,num_feats);
00104         }
00105                 
00106         m_centroids->free_feature_matrix();
00107         m_centroids->set_feature_matrix(centroids);
00108         
00109         
00110         m_is_trained=true;
00111         distance->init(m_centroids,distance->get_rhs());
00112         
00113         SG_FREE(num_per_class);
00114         
00115         return true;
00116     }
00117 
00118 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation