LDA.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 
00013 #ifdef HAVE_LAPACK
00014 #include "classifier/Classifier.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "classifier/LDA.h"
00017 #include "features/Labels.h"
00018 #include "lib/Mathematics.h"
00019 #include "lib/lapack.h"
00020 
00021 using namespace shogun;
00022 
00023 CLDA::CLDA(float64_t gamma)
00024 : CLinearClassifier(), m_gamma(gamma)
00025 {
00026 }
00027 
00028 CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
00029 : CLinearClassifier(), m_gamma(gamma)
00030 {
00031     set_features(traindat);
00032     set_labels(trainlab);
00033 }
00034 
00035 
00036 CLDA::~CLDA()
00037 {
00038 }
00039 
00040 bool CLDA::train(CFeatures* data)
00041 {
00042     ASSERT(labels);
00043     if (data)
00044     {
00045         if (!data->has_property(FP_DOT))
00046             SG_ERROR("Specified features are not of type CDotFeatures\n");
00047         set_features((CDotFeatures*) data);
00048     }
00049     ASSERT(features);
00050     int32_t num_train_labels=0;
00051     int32_t* train_labels=labels->get_int_labels(num_train_labels);
00052     ASSERT(train_labels);
00053 
00054     int32_t num_feat=features->get_dim_feature_space();
00055     int32_t num_vec=features->get_num_vectors();
00056     ASSERT(num_vec==num_train_labels);
00057 
00058     int32_t* classidx_neg=new int32_t[num_vec];
00059     int32_t* classidx_pos=new int32_t[num_vec];
00060 
00061     int32_t i=0;
00062     int32_t j=0;
00063     int32_t num_neg=0;
00064     int32_t num_pos=0;
00065     for (i=0; i<num_train_labels; i++)
00066     {
00067         if (train_labels[i]==-1)
00068             classidx_neg[num_neg++]=i;
00069         else if (train_labels[i]==+1)
00070             classidx_pos[num_pos++]=i;
00071         else
00072         {
00073             SG_ERROR( "found label != +/- 1 bailing...");
00074             return false;
00075         }
00076     }
00077 
00078     if (num_neg<=0 && num_pos<=0)
00079     {
00080       SG_ERROR( "whooooo ? only a single class found\n");
00081         return false;
00082     }
00083 
00084     delete[] w;
00085     w=new float64_t[num_feat];
00086     w_dim=num_feat;
00087 
00088     float64_t* mean_neg=new float64_t[num_feat];
00089     memset(mean_neg,0,num_feat*sizeof(float64_t));
00090 
00091     float64_t* mean_pos=new float64_t[num_feat];
00092     memset(mean_pos,0,num_feat*sizeof(float64_t));
00093 
00094     /* calling external lib */
00095     double* scatter=new double[num_feat*num_feat];
00096     double* buffer=new double[num_feat*CMath::max(num_neg, num_pos)];
00097     int nf = (int) num_feat;
00098 
00099     CSimpleFeatures<float64_t>* rf = (CSimpleFeatures<float64_t>*) features;
00100     //mean neg
00101     for (i=0; i<num_neg; i++)
00102     {
00103         int32_t vlen;
00104         bool vfree;
00105         float64_t* vec=
00106             rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00107         ASSERT(vec);
00108 
00109         for (j=0; j<vlen; j++)
00110         {
00111             mean_neg[j]+=vec[j];
00112             buffer[num_feat*i+j]=vec[j];
00113         }
00114 
00115         rf->free_feature_vector(vec, classidx_neg[i], vfree);
00116     }
00117 
00118     for (j=0; j<num_feat; j++)
00119         mean_neg[j]/=num_neg;
00120 
00121     for (i=0; i<num_neg; i++)
00122     {
00123         for (j=0; j<num_feat; j++)
00124             buffer[num_feat*i+j]-=mean_neg[j];
00125     }
00126     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00127         (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00128     
00129     //mean pos
00130     for (i=0; i<num_pos; i++)
00131     {
00132         int32_t vlen;
00133         bool vfree;
00134         float64_t* vec=
00135             rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00136         ASSERT(vec);
00137 
00138         for (j=0; j<vlen; j++)
00139         {
00140             mean_pos[j]+=vec[j];
00141             buffer[num_feat*i+j]=vec[j];
00142         }
00143 
00144         rf->free_feature_vector(vec, classidx_pos[i], vfree);
00145     }
00146 
00147     for (j=0; j<num_feat; j++)
00148         mean_pos[j]/=num_pos;
00149 
00150     for (i=0; i<num_pos; i++)
00151     {
00152         for (j=0; j<num_feat; j++)
00153             buffer[num_feat*i+j]-=mean_pos[j];
00154     }
00155     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00156         1.0/(num_train_labels-1), buffer, nf, buffer, nf,
00157         1.0/(num_train_labels-1), scatter, nf);
00158 
00159     float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);
00160 
00161     double s=1.0-m_gamma; /* calling external lib; indirectly */
00162     for (i=0; i<num_feat*num_feat; i++)
00163         scatter[i]*=s;
00164 
00165     for (i=0; i<num_feat; i++)
00166         scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00167 
00168     double* inv_scatter= (double*) CMath::pinv(
00169         scatter, num_feat, num_feat, NULL);
00170 
00171     float64_t* w_pos=buffer;
00172     float64_t* w_neg=&buffer[num_feat];
00173 
00174     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00175         (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00176     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00177         (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00178     
00179     bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-CMath::dot(w_pos, mean_pos, num_feat));
00180     for (i=0; i<num_feat; i++)
00181         w[i]=w_pos[i]-w_neg[i];
00182 
00183 #ifdef DEBUG_LDA
00184     SG_PRINT("bias: %f\n", bias);
00185     CMath::display_vector(w, num_feat, "w");
00186     CMath::display_vector(w_pos, num_feat, "w_pos");
00187     CMath::display_vector(w_neg, num_feat, "w_neg");
00188     CMath::display_vector(mean_pos, num_feat, "mean_pos");
00189     CMath::display_vector(mean_neg, num_feat, "mean_neg");
00190 #endif
00191 
00192     delete[] train_labels;
00193     delete[] mean_neg;
00194     delete[] mean_pos;
00195     delete[] scatter;
00196     delete[] inv_scatter;
00197     delete[] classidx_neg;
00198     delete[] classidx_pos;
00199     delete[] buffer;
00200     return true;
00201 }
00202 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation