QDA.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 
00013 #ifdef HAVE_LAPACK
00014 
00015 #include <shogun/multiclass/QDA.h>
00016 #include <shogun/machine/NativeMulticlassMachine.h>
00017 #include <shogun/features/Features.h>
00018 #include <shogun/labels/Labels.h>
00019 #include <shogun/labels/MulticlassLabels.h>
00020 #include <shogun/mathematics/Math.h>
00021 #include <shogun/mathematics/lapack.h>
00022 
00023 using namespace shogun;
00024 
/** Default constructor.
 *
 * Stores the tolerance value and the flag controlling whether the
 * per-class covariance matrices are kept after training, then
 * registers the parameters via init().
 */
CQDA::CQDA(float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance), 
    m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
    init();
}
00031 
/** Constructor taking training data.
 *
 * Same as the default constructor, but additionally sets the given
 * training features and labels on the machine (training itself is
 * deferred to train_machine()).
 */
CQDA::CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance), m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
    init();
    set_features(traindat);
    set_labels(trainlab);
}
00039 
CQDA::~CQDA()
{
    // release this machine's reference on the feature object
    SG_UNREF(m_features);

    // reset the state learnt during training (means, class count)
    cleanup();
}
00046 
/** Register the machine's parameters (for serialization / model
 * selection, as the SG_ADD/MS_* macros indicate) and put the members
 * into a safe initial state.
 */
void CQDA::init()
{
    SG_ADD(&m_tolerance, "m_tolerance", "Tolerance member.", MS_AVAILABLE);
    SG_ADD(&m_store_covs, "m_store_covs", "Store covariances member", MS_NOT_AVAILABLE);
    SG_ADD((CSGObject**) &m_features, "m_features", "Feature object.", MS_NOT_AVAILABLE);
    SG_ADD(&m_means, "m_means", "Mean vectors list", MS_NOT_AVAILABLE);
    SG_ADD(&m_slog, "m_slog", "Vector used in classification", MS_NOT_AVAILABLE);

    //TODO include SGNDArray objects for serialization

    m_features  = NULL;
}
00059 
00060 void CQDA::cleanup()
00061 {
00062     m_means=SGMatrix<float64_t>();
00063 
00064     m_num_classes = 0;
00065 }
00066 
/** Classify all vectors of the given (or previously set) features.
 *
 * For each example x and every class k this computes
 *     A = (x - mean_k) * m_M[k]
 * and scores class k with  -0.5 * (||A||^2 + m_slog[k]);
 * each example is assigned the label of its maximum-scoring class.
 *
 * @param data optional features to classify; must have property FP_DOT
 * @return newly allocated multiclass labels, or NULL if no features set
 */
CMulticlassLabels* CQDA::apply_multiclass(CFeatures* data)
{
    if (data)
    {
        if (!data->has_property(FP_DOT))
            SG_ERROR("Specified features are not of type CDotFeatures\n");

        set_features((CDotFeatures*) data);
    }

    if ( !m_features )
        return NULL;

    int32_t num_vecs = m_features->get_num_vectors();
    ASSERT(num_vecs > 0);
    ASSERT( m_dim == m_features->get_dim_feature_space() );

    CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;

    // X: centered data, A = X * m_M[k]; both column-major with num_vecs rows
    SGMatrix< float64_t > X(num_vecs, m_dim);
    SGMatrix< float64_t > A(num_vecs, m_dim);
    // norm2[i + k*num_vecs] accumulates the score of example i for class k
    SGVector< float64_t > norm2(num_vecs*m_num_classes);
    norm2.zero();

    int i, j, k, vlen;
    bool vfree;
    float64_t* vec;
    for ( k = 0 ; k < m_num_classes ; ++k )
    {
        // X = features - means
        for ( i = 0 ; i < num_vecs ; ++i )
        {
            vec = rf->get_feature_vector(i, vlen, vfree);
            ASSERT(vec);

            for ( j = 0 ; j < m_dim ; ++j )
                X[i + j*num_vecs] = vec[j] - m_means[k*m_dim + j];

            rf->free_feature_vector(vec, i, vfree);

        }

        // A = X * m_M[k]  (column-major GEMM, no transposition)
        cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, num_vecs, m_dim,
            m_dim, 1.0, X.matrix, num_vecs, m_M.get_matrix(k), m_dim, 0.0,
            A.matrix, num_vecs);

        // accumulate the squared row norms of A into norm2(:,k)
        for ( i = 0 ; i < num_vecs ; ++i )
            for ( j = 0 ; j < m_dim ; ++j )
                norm2[i + k*num_vecs] += CMath::sq(A[i + j*num_vecs]);

#ifdef DEBUG_QDA
    CMath::display_matrix(A.matrix, num_vecs, m_dim, "A");
#endif
    }

    // turn the accumulated norms into scores: -0.5 * (||A||^2 + m_slog[k])
    for ( i = 0 ; i < num_vecs ; ++i )
        for ( k = 0 ; k < m_num_classes ; ++k )
        {
            norm2[i + k*num_vecs] += m_slog[k];
            norm2[i + k*num_vecs] *= -0.5;
        }

    CMulticlassLabels* out = new CMulticlassLabels(num_vecs);

    // pick, for each example i, the best class; stride num_vecs walks
    // across the classes of example i in the flat norm2 vector
    for ( i = 0 ; i < num_vecs ; ++i )
        out->set_label(i, SGVector<float64_t>::arg_max(norm2.vector+i, num_vecs, m_num_classes));

#ifdef DEBUG_QDA
    CMath::display_matrix(norm2.vector, num_vecs, m_num_classes, "norm2");
    CMath::display_vector(out->get_labels().vector, num_vecs, "Labels");
#endif

    return out;
}
00141 
00142 bool CQDA::train_machine(CFeatures* data)
00143 {
00144     if ( !m_labels )
00145         SG_ERROR("No labels allocated in QDA training\n");
00146 
00147     if ( data )
00148     {
00149         if ( !data->has_property(FP_DOT) )
00150             SG_ERROR("Speficied features are not of type CDotFeatures\n");
00151         set_features((CDotFeatures*) data);
00152     }
00153     if ( !m_features )
00154         SG_ERROR("No features allocated in QDA training\n");
00155     SGVector< int32_t > train_labels = ((CMulticlassLabels*) m_labels)->get_int_labels();
00156     if ( !train_labels.vector )
00157         SG_ERROR("No train_labels allocated in QDA training\n");
00158 
00159     cleanup();
00160 
00161     m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00162     m_dim = m_features->get_dim_feature_space();
00163     int32_t num_vec  = m_features->get_num_vectors();
00164     if ( num_vec != train_labels.vlen )
00165         SG_ERROR("Dimension mismatch between features and labels in QDA training");
00166 
00167     int32_t* class_idxs = SG_MALLOC(int32_t, num_vec*m_num_classes);
00168     // number of examples of each class
00169     int32_t* class_nums = SG_MALLOC(int32_t, m_num_classes);
00170     memset(class_nums, 0, m_num_classes*sizeof(int32_t));
00171     int32_t class_idx;
00172     int32_t i, j, k;
00173     for ( i = 0 ; i < train_labels.vlen ; ++i )
00174     {
00175         class_idx = train_labels.vector[i];
00176 
00177         if ( class_idx < 0 || class_idx >= m_num_classes )
00178         {
00179             SG_ERROR("found label out of {0, 1, 2, ..., num_classes-1}...");
00180             return false;
00181         }
00182         else
00183         {
00184             class_idxs[ class_idx*num_vec + class_nums[class_idx]++ ] = i;
00185         }
00186     }
00187 
00188     for ( i = 0 ; i < m_num_classes ; ++i )
00189     {
00190         if ( class_nums[i] <= 0 )
00191         {
00192             SG_ERROR("What? One class with no elements\n");
00193             return false;
00194         }
00195     }
00196 
00197     if ( m_store_covs )
00198     {
00199         // cov_dims will be free in m_covs.destroy_ndarray()
00200         index_t * cov_dims = SG_MALLOC(index_t, 3);
00201         cov_dims[0] = m_dim;
00202         cov_dims[1] = m_dim;
00203         cov_dims[2] = m_num_classes;
00204         m_covs = SGNDArray< float64_t >(cov_dims, 3);
00205     }
00206 
00207     m_means = SGMatrix< float64_t >(m_dim, m_num_classes, true);
00208     SGMatrix< float64_t > scalings  = SGMatrix< float64_t >(m_dim, m_num_classes);
00209 
00210     // rot_dims will be freed in rotations.destroy_ndarray()
00211     index_t* rot_dims = SG_MALLOC(index_t, 3);
00212     rot_dims[0] = m_dim;
00213     rot_dims[1] = m_dim;
00214     rot_dims[2] = m_num_classes;
00215     SGNDArray< float64_t > rotations = SGNDArray< float64_t >(rot_dims, 3);
00216 
00217     CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;
00218 
00219     m_means.zero();
00220 
00221     int32_t vlen;
00222     bool vfree;
00223     float64_t* vec;
00224     for ( k = 0 ; k < m_num_classes ; ++k )
00225     {
00226         SGMatrix< float64_t > buffer(class_nums[k], m_dim);
00227         for ( i = 0 ; i < class_nums[k] ; ++i )
00228         {
00229             vec = rf->get_feature_vector(class_idxs[k*num_vec + i], vlen, vfree);
00230             ASSERT(vec);
00231 
00232             for ( j = 0 ; j < vlen ; ++j )
00233             {
00234                 m_means[k*m_dim + j] += vec[j];
00235                 buffer[i + j*class_nums[k]] = vec[j];
00236             }
00237 
00238             rf->free_feature_vector(vec, class_idxs[k*num_vec + i], vfree);
00239         }
00240 
00241         for ( j = 0 ; j < m_dim ; ++j )
00242             m_means[k*m_dim + j] /= class_nums[k];
00243 
00244         for ( i = 0 ; i < class_nums[k] ; ++i )
00245             for ( j = 0 ; j < m_dim ; ++j )
00246                 buffer[i + j*class_nums[k]] -= m_means[k*m_dim + j];
00247 
00248         /* calling external lib, buffer = U * S * V^T, U is not interesting here */
00249         char jobu = 'N', jobvt = 'A';
00250         int m = class_nums[k], n = m_dim;
00251         int lda = m, ldu = m, ldvt = n;
00252         int info = -1;
00253         float64_t * col = scalings.get_column_vector(k);
00254         float64_t * rot_mat = rotations.get_matrix(k);
00255 
00256         wrap_dgesvd(jobu, jobvt, m, n, buffer.matrix, lda, col, NULL, ldu,
00257             rot_mat, ldvt, &info);
00258         ASSERT(info == 0);
00259         buffer=SGMatrix<float64_t>();
00260 
00261         SGVector<float64_t>::vector_multiply(col, col, col, m_dim);
00262         SGVector<float64_t>::scale_vector(1.0/(m-1), col, m_dim);
00263         rotations.transpose_matrix(k);
00264 
00265         if ( m_store_covs )
00266         {
00267             SGMatrix< float64_t > M(n ,n);
00268 
00269             M.matrix = SGVector<float64_t>::clone_vector(rot_mat, n*n);
00270             for ( i = 0 ; i < m_dim ; ++i )
00271                 for ( j = 0 ; j < m_dim ; ++j )
00272                     M[i + j*m_dim] *= scalings[k*m_dim + j];
00273 
00274             cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, n, n, n, 1.0,
00275                 M.matrix, n, rot_mat, n, 0.0, m_covs.get_matrix(k), n);
00276         }
00277     }
00278 
00279     /* Computation of terms required for classification */
00280 
00281     SGVector< float32_t > sinvsqrt(m_dim);
00282 
00283     // M_dims will be freed in m_M.destroy_ndarray()
00284     index_t* M_dims = SG_MALLOC(index_t, 3);
00285     M_dims[0] = m_dim;
00286     M_dims[1] = m_dim;
00287     M_dims[2] = m_num_classes;
00288     m_M = SGNDArray< float64_t >(M_dims, 3);
00289 
00290     m_slog = SGVector< float32_t >(m_num_classes);
00291     m_slog.zero();
00292 
00293     index_t idx = 0;
00294     for ( k = 0 ; k < m_num_classes ; ++k )
00295     {
00296         for ( j = 0 ; j < m_dim ; ++j )
00297         {
00298             sinvsqrt[j] = 1.0 / CMath::sqrt(scalings[k*m_dim + j]);
00299             m_slog[k]  += CMath::log(scalings[k*m_dim + j]);
00300         }
00301 
00302         for ( i = 0 ; i < m_dim ; ++i )
00303             for ( j = 0 ; j < m_dim ; ++j )
00304             {
00305                 idx = k*m_dim*m_dim + i + j*m_dim;
00306                 m_M[idx] = rotations[idx] * sinvsqrt[j];
00307             }
00308     }
00309 
00310 #ifdef DEBUG_QDA
00311     SG_PRINT(">>> QDA machine trained with %d classes\n", m_num_classes);
00312 
00313     SG_PRINT("\n>>> Displaying means ...\n");
00314     CMath::display_matrix(m_means.matrix, m_dim, m_num_classes);
00315 
00316     SG_PRINT("\n>>> Displaying scalings ...\n");
00317     CMath::display_matrix(scalings.matrix, m_dim, m_num_classes);
00318 
00319     SG_PRINT("\n>>> Displaying rotations ... \n");
00320     for ( k = 0 ; k < m_num_classes ; ++k )
00321         CMath::display_matrix(rotations.get_matrix(k), m_dim, m_dim);
00322 
00323     SG_PRINT("\n>>> Displaying sinvsqrt ... \n");
00324     sinvsqrt.display_vector();
00325 
00326     SG_PRINT("\n>>> Diplaying m_M matrices ... \n");
00327     for ( k = 0 ; k < m_num_classes ; ++k )
00328         CMath::display_matrix(m_M.get_matrix(k), m_dim, m_dim);
00329 
00330     SG_PRINT("\n>>> Exit DEBUG_QDA\n");
00331 #endif
00332 
00333     SG_FREE(class_idxs);
00334     SG_FREE(class_nums);
00335     return true;
00336 }
00337 
00338 #endif /* HAVE_LAPACK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation