LDA.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 
00013 #ifdef HAVE_LAPACK
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/machine/LinearMachine.h>
00016 #include <shogun/classifier/LDA.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/mathematics/Math.h>
00020 #include <shogun/mathematics/lapack.h>
00021 
00022 using namespace shogun;
00023 
00024 CLDA::CLDA(float64_t gamma)
00025 : CLinearMachine(), m_gamma(gamma)
00026 {
00027 }
00028 
00029 CLDA::CLDA(float64_t gamma, CDenseFeatures<float64_t>* traindat, CLabels* trainlab)
00030 : CLinearMachine(), m_gamma(gamma)
00031 {
00032     set_features(traindat);
00033     set_labels(trainlab);
00034 }
00035 
00036 
00037 CLDA::~CLDA()
00038 {
00039 }
00040 
00041 bool CLDA::train_machine(CFeatures* data)
00042 {
00043     ASSERT(m_labels);
00044     if (data)
00045     {
00046         if (!data->has_property(FP_DOT))
00047             SG_ERROR("Specified features are not of type CDotFeatures\n");
00048         set_features((CDotFeatures*) data);
00049     }
00050     ASSERT(features);
00051     SGVector<int32_t> train_labels=((CBinaryLabels*) m_labels)->get_int_labels();
00052     ASSERT(train_labels.vector);
00053 
00054     int32_t num_feat=features->get_dim_feature_space();
00055     int32_t num_vec=features->get_num_vectors();
00056     ASSERT(num_vec==train_labels.vlen);
00057 
00058     int32_t* classidx_neg=SG_MALLOC(int32_t, num_vec);
00059     int32_t* classidx_pos=SG_MALLOC(int32_t, num_vec);
00060 
00061     int32_t i=0;
00062     int32_t j=0;
00063     int32_t num_neg=0;
00064     int32_t num_pos=0;
00065     for (i=0; i<train_labels.vlen; i++)
00066     {
00067         if (train_labels.vector[i]==-1)
00068             classidx_neg[num_neg++]=i;
00069         else if (train_labels.vector[i]==+1)
00070             classidx_pos[num_pos++]=i;
00071         else
00072         {
00073             SG_ERROR( "found label != +/- 1 bailing...");
00074             return false;
00075         }
00076     }
00077 
00078     if (num_neg<=0 || num_pos<=0)
00079     {
00080         SG_ERROR( "whooooo ? only a single class found\n");
00081         return false;
00082     }
00083 
00084     w=SGVector<float64_t>(num_feat);
00085 
00086     float64_t* mean_neg=SG_MALLOC(float64_t, num_feat);
00087     memset(mean_neg,0,num_feat*sizeof(float64_t));
00088 
00089     float64_t* mean_pos=SG_MALLOC(float64_t, num_feat);
00090     memset(mean_pos,0,num_feat*sizeof(float64_t));
00091 
00092     /* calling external lib */
00093     double* scatter=SG_MALLOC(double, num_feat*num_feat);
00094     double* buffer=SG_MALLOC(double, num_feat*CMath::max(num_neg, num_pos));
00095     int nf = (int) num_feat;
00096 
00097     CDenseFeatures<float64_t>* rf = (CDenseFeatures<float64_t>*) features;
00098     //mean neg
00099     for (i=0; i<num_neg; i++)
00100     {
00101         int32_t vlen;
00102         bool vfree;
00103         float64_t* vec=
00104             rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00105         ASSERT(vec);
00106 
00107         for (j=0; j<vlen; j++)
00108         {
00109             mean_neg[j]+=vec[j];
00110             buffer[num_feat*i+j]=vec[j];
00111         }
00112 
00113         rf->free_feature_vector(vec, classidx_neg[i], vfree);
00114     }
00115 
00116     for (j=0; j<num_feat; j++)
00117         mean_neg[j]/=num_neg;
00118 
00119     for (i=0; i<num_neg; i++)
00120     {
00121         for (j=0; j<num_feat; j++)
00122             buffer[num_feat*i+j]-=mean_neg[j];
00123     }
00124     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00125         (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00126 
00127     //mean pos
00128     for (i=0; i<num_pos; i++)
00129     {
00130         int32_t vlen;
00131         bool vfree;
00132         float64_t* vec=
00133             rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00134         ASSERT(vec);
00135 
00136         for (j=0; j<vlen; j++)
00137         {
00138             mean_pos[j]+=vec[j];
00139             buffer[num_feat*i+j]=vec[j];
00140         }
00141 
00142         rf->free_feature_vector(vec, classidx_pos[i], vfree);
00143     }
00144 
00145     for (j=0; j<num_feat; j++)
00146         mean_pos[j]/=num_pos;
00147 
00148     for (i=0; i<num_pos; i++)
00149     {
00150         for (j=0; j<num_feat; j++)
00151             buffer[num_feat*i+j]-=mean_pos[j];
00152     }
00153     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00154         1.0/(train_labels.vlen-1), buffer, nf, buffer, nf,
00155         1.0/(train_labels.vlen-1), scatter, nf);
00156 
00157     float64_t trace=SGMatrix<float64_t>::trace((float64_t*) scatter, num_feat, num_feat);
00158 
00159     double s=1.0-m_gamma; /* calling external lib; indirectly */
00160     for (i=0; i<num_feat*num_feat; i++)
00161         scatter[i]*=s;
00162 
00163     for (i=0; i<num_feat; i++)
00164         scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00165 
00166     double* inv_scatter= (double*) SGMatrix<float64_t>::pinv(
00167         scatter, num_feat, num_feat, NULL);
00168 
00169     float64_t* w_pos=buffer;
00170     float64_t* w_neg=&buffer[num_feat];
00171 
00172     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00173         (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00174     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00175         (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00176 
00177     bias=0.5*(SGVector<float64_t>::dot(w_neg, mean_neg, num_feat)-SGVector<float64_t>::dot(w_pos, mean_pos, num_feat));
00178     for (i=0; i<num_feat; i++)
00179         w.vector[i]=w_pos[i]-w_neg[i];
00180 
00181 #ifdef DEBUG_LDA
00182     SG_PRINT("bias: %f\n", bias);
00183     SGVector<float64_t>::display_vector(w.vector, num_feat, "w");
00184     SGVector<float64_t>::display_vector(w_pos, num_feat, "w_pos");
00185     SGVector<float64_t>::display_vector(w_neg, num_feat, "w_neg");
00186     SGVector<float64_t>::display_vector(mean_pos, num_feat, "mean_pos");
00187     SGVector<float64_t>::display_vector(mean_neg, num_feat, "mean_neg");
00188 #endif
00189 
00190     SG_FREE(mean_neg);
00191     SG_FREE(mean_pos);
00192     SG_FREE(scatter);
00193     SG_FREE(inv_scatter);
00194     SG_FREE(classidx_neg);
00195     SG_FREE(classidx_pos);
00196     SG_FREE(buffer);
00197     return true;
00198 }
00199 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation