ConjugateIndex.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/multiclass/ConjugateIndex.h>
00012 #ifdef HAVE_LAPACK
00013 #include <shogun/machine/Machine.h>
00014 #include <shogun/features/Features.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/MulticlassLabels.h>
00017 #include <shogun/mathematics/lapack.h>
00018 #include <shogun/mathematics/Math.h>
00019 #include <shogun/lib/Signal.h>
00020 
00021 using namespace shogun;
00022 
00023 CConjugateIndex::CConjugateIndex() : CMachine()
00024 {
00025     m_classes = NULL;
00026     m_features = NULL;
00027 };
00028 
00029 CConjugateIndex::CConjugateIndex(CFeatures* train_features, CLabels* train_labels) : CMachine()
00030 {
00031     m_features = NULL;
00032     set_features(train_features);
00033     set_labels(train_labels);
00034     m_classes = NULL;
00035 };
00036 
00037 CConjugateIndex::~CConjugateIndex()
00038 {
00039     clean_classes();
00040     SG_UNREF(m_features);
00041 };
00042 
00043 void CConjugateIndex::set_features(CFeatures* features)
00044 {
00045     ASSERT(features->get_feature_class()==C_DENSE);
00046     SG_REF(features);
00047     SG_UNREF(m_features);
00048     m_features = (CDenseFeatures<float64_t>*)features;
00049 }
00050 
00051 CDenseFeatures<float64_t>* CConjugateIndex::get_features()
00052 {
00053     SG_REF(m_features);
00054     return m_features;
00055 }
00056 
00057 void CConjugateIndex::clean_classes()
00058 {
00059     if (m_classes)
00060     {
00061         for (int32_t i=0; i<m_num_classes; i++)
00062             m_classes[i]=SGMatrix<float64_t>();
00063 
00064         delete[] m_classes;
00065     }
00066 }
00067 
00068 bool CConjugateIndex::train_machine(CFeatures* data)
00069 {
00070     if (data)
00071         set_features(data);
00072 
00073     ASSERT(m_labels);
00074     ASSERT(m_labels->get_label_type()==LT_MULTICLASS);
00075 
00076     m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00077     ASSERT(m_num_classes>=2);
00078     clean_classes();
00079 
00080     int32_t num_vectors;
00081     int32_t num_features;
00082     float64_t* feature_matrix = m_features->get_feature_matrix(num_features,num_vectors);
00083 
00084     m_classes = new SGMatrix<float64_t>[m_num_classes]();
00085     for (int32_t i=0; i<m_num_classes; i++)
00086         m_classes[i] = SGMatrix<float64_t>(num_features,num_features);
00087 
00088     m_feature_vector = SGVector<float64_t>(num_features);
00089 
00090     SG_PROGRESS(0,0,m_num_classes-1);
00091 
00092     for (int32_t label=0; label<m_num_classes; label++)
00093     {
00094         int32_t count = 0;
00095         for (int32_t i=0; i<num_vectors; i++)
00096         {
00097             if (((CMulticlassLabels*) m_labels)->get_int_label(i) == label)
00098                 count++;
00099         }
00100 
00101         SGMatrix<float64_t> class_feature_matrix(num_features,count);
00102         SGMatrix<float64_t> matrix(count,count);
00103         SGMatrix<float64_t> helper_matrix(num_features,count);
00104 
00105         count = 0;
00106         for (int32_t i=0; i<num_vectors; i++)
00107         {
00108             if (((CMulticlassLabels*) m_labels)->get_label(i) == label)
00109             {
00110                 memcpy(class_feature_matrix.matrix+count*num_features,
00111                        feature_matrix+i*num_features,
00112                        sizeof(float64_t)*num_features);
00113                 count++;
00114             }
00115         }
00116 
00117         cblas_dgemm(CblasColMajor,CblasTrans,CblasNoTrans,
00118                     count,count,num_features,
00119                     1.0,class_feature_matrix.matrix,num_features,
00120                     class_feature_matrix.matrix,num_features,
00121                     0.0,matrix.matrix,count);
00122 
00123         SGMatrix<float64_t>::inverse(matrix);
00124 
00125         cblas_dgemm(CblasColMajor,CblasNoTrans,CblasTrans,
00126                     count,num_features,count,
00127                     1.0,matrix.matrix,count,
00128                     class_feature_matrix.matrix,num_features,
00129                     0.0,helper_matrix.matrix,count);
00130 
00131         cblas_dgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,
00132                     num_features,num_features,count,
00133                     1.0,class_feature_matrix.matrix,num_features,
00134                     helper_matrix.matrix,count,
00135                     0.0,m_classes[label].matrix,num_features);
00136 
00137         SG_PROGRESS(label+1,0,m_num_classes);
00138     }
00139     SG_DONE();
00140 
00141     return true;
00142 };
00143 
00144 CMulticlassLabels* CConjugateIndex::apply_multiclass(CFeatures* data)
00145 {
00146     if (data)
00147         set_features(data);
00148 
00149     ASSERT(m_features);
00150 
00151     ASSERT(m_classes);
00152     ASSERT(m_num_classes>1);
00153     ASSERT(m_features->get_num_features()==m_feature_vector.vlen);
00154 
00155     int32_t num_vectors = m_features->get_num_vectors();
00156 
00157     CMulticlassLabels* predicted_labels = new CMulticlassLabels(num_vectors);
00158 
00159     for (int32_t i=0; i<num_vectors;i++)
00160     {
00161         SG_PROGRESS(i,0,num_vectors-1);
00162         predicted_labels->set_label(i,apply_one(i));
00163     }
00164     SG_DONE();
00165 
00166     return predicted_labels;
00167 };
00168 
00169 float64_t CConjugateIndex::conjugate_index(SGVector<float64_t> feature_vector, int32_t label)
00170 {
00171     int32_t num_features = feature_vector.vlen;
00172     float64_t norm = cblas_ddot(num_features,feature_vector.vector,1,
00173                                 feature_vector.vector,1);
00174 
00175     cblas_dgemv(CblasColMajor,CblasNoTrans,
00176                 num_features,num_features,
00177                 1.0,m_classes[label].matrix,num_features,
00178                 feature_vector.vector,1,
00179                 0.0,m_feature_vector.vector,1);
00180 
00181     float64_t product = cblas_ddot(num_features,feature_vector.vector,1,
00182                                    m_feature_vector.vector,1);
00183     return product/norm;
00184 };
00185 
00186 float64_t CConjugateIndex::apply_one(int32_t index)
00187 {
00188     int32_t predicted_label = 0;
00189     float64_t max_conjugate_index = 0.0;
00190     float64_t current_conjugate_index;
00191 
00192     SGVector<float64_t> feature_vector = m_features->get_feature_vector(index);
00193     for (int32_t i=0; i<m_num_classes; i++)
00194     {
00195         current_conjugate_index = conjugate_index(feature_vector,i);
00196 
00197         if (current_conjugate_index > max_conjugate_index)
00198         {
00199             max_conjugate_index = current_conjugate_index;
00200             predicted_label = i;
00201         }
00202     }
00203 
00204     return predicted_label;
00205 };
00206 
00207 #endif /* HAVE_LAPACK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation