00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012
00013 #ifdef HAVE_LAPACK
00014 #include "classifier/Classifier.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "classifier/LDA.h"
00017 #include "features/Labels.h"
00018 #include "lib/Mathematics.h"
00019 #include "lib/lapack.h"
00020
00021 using namespace shogun;
00022
00023 CLDA::CLDA(float64_t gamma)
00024 : CLinearClassifier(), m_gamma(gamma)
00025 {
00026 }
00027
00028 CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
00029 : CLinearClassifier(), m_gamma(gamma)
00030 {
00031 set_features(traindat);
00032 set_labels(trainlab);
00033 }
00034
00035
00036 CLDA::~CLDA()
00037 {
00038 }
00039
00040 bool CLDA::train(CFeatures* data)
00041 {
00042 ASSERT(labels);
00043 if (data)
00044 {
00045 if (!data->has_property(FP_DOT))
00046 SG_ERROR("Specified features are not of type CDotFeatures\n");
00047 set_features((CDotFeatures*) data);
00048 }
00049 ASSERT(features);
00050 int32_t num_train_labels=0;
00051 int32_t* train_labels=labels->get_int_labels(num_train_labels);
00052 ASSERT(train_labels);
00053
00054 int32_t num_feat=features->get_dim_feature_space();
00055 int32_t num_vec=features->get_num_vectors();
00056 ASSERT(num_vec==num_train_labels);
00057
00058 int32_t* classidx_neg=new int32_t[num_vec];
00059 int32_t* classidx_pos=new int32_t[num_vec];
00060
00061 int32_t i=0;
00062 int32_t j=0;
00063 int32_t num_neg=0;
00064 int32_t num_pos=0;
00065 for (i=0; i<num_train_labels; i++)
00066 {
00067 if (train_labels[i]==-1)
00068 classidx_neg[num_neg++]=i;
00069 else if (train_labels[i]==+1)
00070 classidx_pos[num_pos++]=i;
00071 else
00072 {
00073 SG_ERROR( "found label != +/- 1 bailing...");
00074 return false;
00075 }
00076 }
00077
00078 if (num_neg<=0 && num_pos<=0)
00079 {
00080 SG_ERROR( "whooooo ? only a single class found\n");
00081 return false;
00082 }
00083
00084 delete[] w;
00085 w=new float64_t[num_feat];
00086 w_dim=num_feat;
00087
00088 float64_t* mean_neg=new float64_t[num_feat];
00089 memset(mean_neg,0,num_feat*sizeof(float64_t));
00090
00091 float64_t* mean_pos=new float64_t[num_feat];
00092 memset(mean_pos,0,num_feat*sizeof(float64_t));
00093
00094
00095 double* scatter=new double[num_feat*num_feat];
00096 double* buffer=new double[num_feat*CMath::max(num_neg, num_pos)];
00097 int nf = (int) num_feat;
00098
00099 CSimpleFeatures<float64_t>* rf = (CSimpleFeatures<float64_t>*) features;
00100
00101 for (i=0; i<num_neg; i++)
00102 {
00103 int32_t vlen;
00104 bool vfree;
00105 float64_t* vec=
00106 rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00107 ASSERT(vec);
00108
00109 for (j=0; j<vlen; j++)
00110 {
00111 mean_neg[j]+=vec[j];
00112 buffer[num_feat*i+j]=vec[j];
00113 }
00114
00115 rf->free_feature_vector(vec, classidx_neg[i], vfree);
00116 }
00117
00118 for (j=0; j<num_feat; j++)
00119 mean_neg[j]/=num_neg;
00120
00121 for (i=0; i<num_neg; i++)
00122 {
00123 for (j=0; j<num_feat; j++)
00124 buffer[num_feat*i+j]-=mean_neg[j];
00125 }
00126 cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00127 (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00128
00129
00130 for (i=0; i<num_pos; i++)
00131 {
00132 int32_t vlen;
00133 bool vfree;
00134 float64_t* vec=
00135 rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00136 ASSERT(vec);
00137
00138 for (j=0; j<vlen; j++)
00139 {
00140 mean_pos[j]+=vec[j];
00141 buffer[num_feat*i+j]=vec[j];
00142 }
00143
00144 rf->free_feature_vector(vec, classidx_pos[i], vfree);
00145 }
00146
00147 for (j=0; j<num_feat; j++)
00148 mean_pos[j]/=num_pos;
00149
00150 for (i=0; i<num_pos; i++)
00151 {
00152 for (j=0; j<num_feat; j++)
00153 buffer[num_feat*i+j]-=mean_pos[j];
00154 }
00155 cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00156 1.0/(num_train_labels-1), buffer, nf, buffer, nf,
00157 1.0/(num_train_labels-1), scatter, nf);
00158
00159 float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);
00160
00161 double s=1.0-m_gamma;
00162 for (i=0; i<num_feat*num_feat; i++)
00163 scatter[i]*=s;
00164
00165 for (i=0; i<num_feat; i++)
00166 scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00167
00168 double* inv_scatter= (double*) CMath::pinv(
00169 scatter, num_feat, num_feat, NULL);
00170
00171 float64_t* w_pos=buffer;
00172 float64_t* w_neg=&buffer[num_feat];
00173
00174 cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00175 (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00176 cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00177 (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00178
00179 bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-CMath::dot(w_pos, mean_pos, num_feat));
00180 for (i=0; i<num_feat; i++)
00181 w[i]=w_pos[i]-w_neg[i];
00182
00183 #ifdef DEBUG_LDA
00184 SG_PRINT("bias: %f\n", bias);
00185 CMath::display_vector(w, num_feat, "w");
00186 CMath::display_vector(w_pos, num_feat, "w_pos");
00187 CMath::display_vector(w_neg, num_feat, "w_neg");
00188 CMath::display_vector(mean_pos, num_feat, "mean_pos");
00189 CMath::display_vector(mean_neg, num_feat, "mean_neg");
00190 #endif
00191
00192 delete[] train_labels;
00193 delete[] mean_neg;
00194 delete[] mean_pos;
00195 delete[] scatter;
00196 delete[] inv_scatter;
00197 delete[] classidx_neg;
00198 delete[] classidx_pos;
00199 delete[] buffer;
00200 return true;
00201 }
00202 #endif