00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/common.h>
00012
00013 #ifdef HAVE_LAPACK
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/machine/LinearMachine.h>
00016 #include <shogun/classifier/LDA.h>
00017 #include <shogun/features/Labels.h>
00018 #include <shogun/mathematics/Math.h>
00019 #include <shogun/mathematics/lapack.h>
00020
00021 using namespace shogun;
00022
00023 CLDA::CLDA(float64_t gamma)
00024 : CLinearMachine(), m_gamma(gamma)
00025 {
00026 }
00027
00028 CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
00029 : CLinearMachine(), m_gamma(gamma)
00030 {
00031 set_features(traindat);
00032 set_labels(trainlab);
00033 }
00034
00035
00036 CLDA::~CLDA()
00037 {
00038 }
00039
00040 bool CLDA::train_machine(CFeatures* data)
00041 {
00042 ASSERT(labels);
00043 if (data)
00044 {
00045 if (!data->has_property(FP_DOT))
00046 SG_ERROR("Specified features are not of type CDotFeatures\n");
00047 set_features((CDotFeatures*) data);
00048 }
00049 ASSERT(features);
00050 SGVector<int32_t> train_labels=labels->get_int_labels();
00051 ASSERT(train_labels.vector);
00052
00053 int32_t num_feat=features->get_dim_feature_space();
00054 int32_t num_vec=features->get_num_vectors();
00055 ASSERT(num_vec==train_labels.vlen);
00056
00057 int32_t* classidx_neg=SG_MALLOC(int32_t, num_vec);
00058 int32_t* classidx_pos=SG_MALLOC(int32_t, num_vec);
00059
00060 int32_t i=0;
00061 int32_t j=0;
00062 int32_t num_neg=0;
00063 int32_t num_pos=0;
00064 for (i=0; i<train_labels.vlen; i++)
00065 {
00066 if (train_labels.vector[i]==-1)
00067 classidx_neg[num_neg++]=i;
00068 else if (train_labels.vector[i]==+1)
00069 classidx_pos[num_pos++]=i;
00070 else
00071 {
00072 SG_ERROR( "found label != +/- 1 bailing...");
00073 return false;
00074 }
00075 }
00076
00077 if (num_neg<=0 && num_pos<=0)
00078 {
00079 SG_ERROR( "whooooo ? only a single class found\n");
00080 return false;
00081 }
00082
00083 SG_FREE(w);
00084 w=SG_MALLOC(float64_t, num_feat);
00085 w_dim=num_feat;
00086
00087 float64_t* mean_neg=SG_MALLOC(float64_t, num_feat);
00088 memset(mean_neg,0,num_feat*sizeof(float64_t));
00089
00090 float64_t* mean_pos=SG_MALLOC(float64_t, num_feat);
00091 memset(mean_pos,0,num_feat*sizeof(float64_t));
00092
00093
00094 double* scatter=SG_MALLOC(double, num_feat*num_feat);
00095 double* buffer=SG_MALLOC(double, num_feat*CMath::max(num_neg, num_pos));
00096 int nf = (int) num_feat;
00097
00098 CSimpleFeatures<float64_t>* rf = (CSimpleFeatures<float64_t>*) features;
00099
00100 for (i=0; i<num_neg; i++)
00101 {
00102 int32_t vlen;
00103 bool vfree;
00104 float64_t* vec=
00105 rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00106 ASSERT(vec);
00107
00108 for (j=0; j<vlen; j++)
00109 {
00110 mean_neg[j]+=vec[j];
00111 buffer[num_feat*i+j]=vec[j];
00112 }
00113
00114 rf->free_feature_vector(vec, classidx_neg[i], vfree);
00115 }
00116
00117 for (j=0; j<num_feat; j++)
00118 mean_neg[j]/=num_neg;
00119
00120 for (i=0; i<num_neg; i++)
00121 {
00122 for (j=0; j<num_feat; j++)
00123 buffer[num_feat*i+j]-=mean_neg[j];
00124 }
00125 cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00126 (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00127
00128
00129 for (i=0; i<num_pos; i++)
00130 {
00131 int32_t vlen;
00132 bool vfree;
00133 float64_t* vec=
00134 rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00135 ASSERT(vec);
00136
00137 for (j=0; j<vlen; j++)
00138 {
00139 mean_pos[j]+=vec[j];
00140 buffer[num_feat*i+j]=vec[j];
00141 }
00142
00143 rf->free_feature_vector(vec, classidx_pos[i], vfree);
00144 }
00145
00146 for (j=0; j<num_feat; j++)
00147 mean_pos[j]/=num_pos;
00148
00149 for (i=0; i<num_pos; i++)
00150 {
00151 for (j=0; j<num_feat; j++)
00152 buffer[num_feat*i+j]-=mean_pos[j];
00153 }
00154 cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00155 1.0/(train_labels.vlen-1), buffer, nf, buffer, nf,
00156 1.0/(train_labels.vlen-1), scatter, nf);
00157
00158 float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);
00159
00160 double s=1.0-m_gamma;
00161 for (i=0; i<num_feat*num_feat; i++)
00162 scatter[i]*=s;
00163
00164 for (i=0; i<num_feat; i++)
00165 scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00166
00167 double* inv_scatter= (double*) CMath::pinv(
00168 scatter, num_feat, num_feat, NULL);
00169
00170 float64_t* w_pos=buffer;
00171 float64_t* w_neg=&buffer[num_feat];
00172
00173 cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00174 (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00175 cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00176 (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00177
00178 bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-CMath::dot(w_pos, mean_pos, num_feat));
00179 for (i=0; i<num_feat; i++)
00180 w[i]=w_pos[i]-w_neg[i];
00181
00182 #ifdef DEBUG_LDA
00183 SG_PRINT("bias: %f\n", bias);
00184 CMath::display_vector(w, num_feat, "w");
00185 CMath::display_vector(w_pos, num_feat, "w_pos");
00186 CMath::display_vector(w_neg, num_feat, "w_neg");
00187 CMath::display_vector(mean_pos, num_feat, "mean_pos");
00188 CMath::display_vector(mean_neg, num_feat, "mean_neg");
00189 #endif
00190
00191 train_labels.free_vector();
00192 SG_FREE(mean_neg);
00193 SG_FREE(mean_pos);
00194 SG_FREE(scatter);
00195 SG_FREE(inv_scatter);
00196 SG_FREE(classidx_neg);
00197 SG_FREE(classidx_pos);
00198 SG_FREE(buffer);
00199 return true;
00200 }
00201 #endif