LDA.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 
00013 #ifdef HAVE_LAPACK
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/machine/LinearMachine.h>
00016 #include <shogun/classifier/LDA.h>
00017 #include <shogun/features/Labels.h>
00018 #include <shogun/mathematics/Math.h>
00019 #include <shogun/mathematics/lapack.h>
00020 
00021 using namespace shogun;
00022 
00023 CLDA::CLDA(float64_t gamma)
00024 : CLinearMachine(), m_gamma(gamma)
00025 {
00026 }
00027 
00028 CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
00029 : CLinearMachine(), m_gamma(gamma)
00030 {
00031     set_features(traindat);
00032     set_labels(trainlab);
00033 }
00034 
00035 
00036 CLDA::~CLDA()
00037 {
00038 }
00039 
00040 bool CLDA::train_machine(CFeatures* data)
00041 {
00042     ASSERT(labels);
00043     if (data)
00044     {
00045         if (!data->has_property(FP_DOT))
00046             SG_ERROR("Specified features are not of type CDotFeatures\n");
00047         set_features((CDotFeatures*) data);
00048     }
00049     ASSERT(features);
00050     SGVector<int32_t> train_labels=labels->get_int_labels();
00051     ASSERT(train_labels.vector);
00052 
00053     int32_t num_feat=features->get_dim_feature_space();
00054     int32_t num_vec=features->get_num_vectors();
00055     ASSERT(num_vec==train_labels.vlen);
00056 
00057     int32_t* classidx_neg=SG_MALLOC(int32_t, num_vec);
00058     int32_t* classidx_pos=SG_MALLOC(int32_t, num_vec);
00059 
00060     int32_t i=0;
00061     int32_t j=0;
00062     int32_t num_neg=0;
00063     int32_t num_pos=0;
00064     for (i=0; i<train_labels.vlen; i++)
00065     {
00066         if (train_labels.vector[i]==-1)
00067             classidx_neg[num_neg++]=i;
00068         else if (train_labels.vector[i]==+1)
00069             classidx_pos[num_pos++]=i;
00070         else
00071         {
00072             SG_ERROR( "found label != +/- 1 bailing...");
00073             return false;
00074         }
00075     }
00076 
00077     if (num_neg<=0 && num_pos<=0)
00078     {
00079       SG_ERROR( "whooooo ? only a single class found\n");
00080         return false;
00081     }
00082 
00083     SG_FREE(w);
00084     w=SG_MALLOC(float64_t, num_feat);
00085     w_dim=num_feat;
00086 
00087     float64_t* mean_neg=SG_MALLOC(float64_t, num_feat);
00088     memset(mean_neg,0,num_feat*sizeof(float64_t));
00089 
00090     float64_t* mean_pos=SG_MALLOC(float64_t, num_feat);
00091     memset(mean_pos,0,num_feat*sizeof(float64_t));
00092 
00093     /* calling external lib */
00094     double* scatter=SG_MALLOC(double, num_feat*num_feat);
00095     double* buffer=SG_MALLOC(double, num_feat*CMath::max(num_neg, num_pos));
00096     int nf = (int) num_feat;
00097 
00098     CSimpleFeatures<float64_t>* rf = (CSimpleFeatures<float64_t>*) features;
00099     //mean neg
00100     for (i=0; i<num_neg; i++)
00101     {
00102         int32_t vlen;
00103         bool vfree;
00104         float64_t* vec=
00105             rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00106         ASSERT(vec);
00107 
00108         for (j=0; j<vlen; j++)
00109         {
00110             mean_neg[j]+=vec[j];
00111             buffer[num_feat*i+j]=vec[j];
00112         }
00113 
00114         rf->free_feature_vector(vec, classidx_neg[i], vfree);
00115     }
00116 
00117     for (j=0; j<num_feat; j++)
00118         mean_neg[j]/=num_neg;
00119 
00120     for (i=0; i<num_neg; i++)
00121     {
00122         for (j=0; j<num_feat; j++)
00123             buffer[num_feat*i+j]-=mean_neg[j];
00124     }
00125     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00126         (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00127     
00128     //mean pos
00129     for (i=0; i<num_pos; i++)
00130     {
00131         int32_t vlen;
00132         bool vfree;
00133         float64_t* vec=
00134             rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00135         ASSERT(vec);
00136 
00137         for (j=0; j<vlen; j++)
00138         {
00139             mean_pos[j]+=vec[j];
00140             buffer[num_feat*i+j]=vec[j];
00141         }
00142 
00143         rf->free_feature_vector(vec, classidx_pos[i], vfree);
00144     }
00145 
00146     for (j=0; j<num_feat; j++)
00147         mean_pos[j]/=num_pos;
00148 
00149     for (i=0; i<num_pos; i++)
00150     {
00151         for (j=0; j<num_feat; j++)
00152             buffer[num_feat*i+j]-=mean_pos[j];
00153     }
00154     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00155         1.0/(train_labels.vlen-1), buffer, nf, buffer, nf,
00156         1.0/(train_labels.vlen-1), scatter, nf);
00157 
00158     float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);
00159 
00160     double s=1.0-m_gamma; /* calling external lib; indirectly */
00161     for (i=0; i<num_feat*num_feat; i++)
00162         scatter[i]*=s;
00163 
00164     for (i=0; i<num_feat; i++)
00165         scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00166 
00167     double* inv_scatter= (double*) CMath::pinv(
00168         scatter, num_feat, num_feat, NULL);
00169 
00170     float64_t* w_pos=buffer;
00171     float64_t* w_neg=&buffer[num_feat];
00172 
00173     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00174         (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00175     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00176         (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00177     
00178     bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-CMath::dot(w_pos, mean_pos, num_feat));
00179     for (i=0; i<num_feat; i++)
00180         w[i]=w_pos[i]-w_neg[i];
00181 
00182 #ifdef DEBUG_LDA
00183     SG_PRINT("bias: %f\n", bias);
00184     CMath::display_vector(w, num_feat, "w");
00185     CMath::display_vector(w_pos, num_feat, "w_pos");
00186     CMath::display_vector(w_neg, num_feat, "w_neg");
00187     CMath::display_vector(mean_pos, num_feat, "mean_pos");
00188     CMath::display_vector(mean_neg, num_feat, "mean_neg");
00189 #endif
00190 
00191     train_labels.free_vector();
00192     SG_FREE(mean_neg);
00193     SG_FREE(mean_pos);
00194     SG_FREE(scatter);
00195     SG_FREE(inv_scatter);
00196     SG_FREE(classidx_neg);
00197     SG_FREE(classidx_pos);
00198     SG_FREE(buffer);
00199     return true;
00200 }
00201 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation