KNN.cpp — source listing of this file; see the SHOGUN documentation for the rendered, cross-referenced version.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006 Christian Gehl
00008  * Written (W) 2006-2009 Soeren Sonnenburg
00009  * Written (W) 2011 Sergey Lisitsyn
00010  * Written (W) 2012 Fernando José Iglesias García, cover tree support
00011  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00012  */
00013 
00014 #include <shogun/multiclass/KNN.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/MulticlassLabels.h>
00017 #include <shogun/mathematics/Math.h>
00018 #include <shogun/lib/Signal.h>
00019 #include <shogun/lib/JLCoverTree.h>
00020 #include <shogun/lib/Time.h>
00021 #include <shogun/base/Parameter.h>
00022 
00023 //#define BENCHMARK_KNN
00024 //#define DEBUG_KNN
00025 
00026 using namespace shogun;
00027 
/** Default constructor.
 *
 * Builds a KNN machine with the default settings established by init()
 * (k=3, q=1.0, cover tree disabled, no classes yet).
 */
CKNN::CKNN()
: CDistanceMachine()
{
    init();
}
00033 
/** Constructor.
 *
 * @param k number of nearest neighbors to consider
 * @param d distance measure between feature vectors (must be non-NULL)
 * @param trainlab training labels (must be non-NULL)
 */
CKNN::CKNN(int32_t k, CDistance* d, CLabels* trainlab)
: CDistanceMachine()
{
    init();

    m_k=k;

    ASSERT(d);
    ASSERT(trainlab);

    set_distance(d);
    set_labels(trainlab);
    // only the length is recorded here; the label values themselves are
    // cloned into m_train_labels.vector later, in train_machine()
    m_train_labels.vlen=trainlab->get_num_labels();
}
00048 
/** Shared constructor initialization: sets default parameter values and
 * registers the parameters for serialization / model selection.
 */
void CKNN::init()
{
    /* do not store model features by default (CDistanceMachine::apply(...) is
     * overwritten */
    set_store_model_features(false);

    m_k=3;
    m_q=1.0;
    m_use_covertree=false;
    m_num_classes=0;

    /* use the method classify_multiply_k to experiment with different values 
     * of k */
    // only m_q is exposed to model selection (MS_AVAILABLE)
    SG_ADD(&m_k, "m_k", "Parameter k", MS_NOT_AVAILABLE);
    SG_ADD(&m_q, "m_q", "Parameter q", MS_AVAILABLE);
    SG_ADD(&m_use_covertree, "m_use_covertree", "Parameter use_covertree", MS_NOT_AVAILABLE);
    SG_ADD(&m_num_classes, "m_num_classes", "Number of classes", MS_NOT_AVAILABLE);
}
00067 
/** Destructor. Nothing to release explicitly; members clean up via their
 * own destructors. */
CKNN::~CKNN()
{
}
00071 
00072 bool CKNN::train_machine(CFeatures* data)
00073 {
00074     ASSERT(m_labels);
00075     ASSERT(distance);
00076 
00077     if (data)
00078     {
00079         if (m_labels->get_num_labels() != data->get_num_vectors())
00080             SG_ERROR("Number of training vectors does not match number of labels\n");
00081         distance->init(data, data);
00082     }
00083 
00084     SGVector<int32_t> lab=((CMulticlassLabels*) m_labels)->get_int_labels();
00085     m_train_labels.vlen=lab.vlen;
00086     m_train_labels.vector=SGVector<int32_t>::clone_vector(lab.vector, lab.vlen);
00087     ASSERT(m_train_labels.vlen>0);
00088 
00089     int32_t max_class=m_train_labels.vector[0];
00090     int32_t min_class=m_train_labels.vector[0];
00091 
00092     for (int32_t i=1; i<m_train_labels.vlen; i++)
00093     {
00094         max_class=CMath::max(max_class, m_train_labels.vector[i]);
00095         min_class=CMath::min(min_class, m_train_labels.vector[i]);
00096     }
00097 
00098     for (int32_t i=0; i<m_train_labels.vlen; i++)
00099         m_train_labels.vector[i]-=min_class;
00100 
00101     m_min_label=min_class;
00102     m_num_classes=max_class-min_class+1;
00103 
00104     SG_INFO("m_num_classes: %d (%+d to %+d) num_train: %d\n", m_num_classes,
00105             min_class, max_class, m_train_labels.vlen);
00106 
00107     return true;
00108 }
00109 
/** Applies the trained k-NN classifier to test data.
 *
 * Performs either a brute-force search (sorting all train-test distances
 * per query) or, when m_use_covertree is set, a batch query on JL cover
 * trees. The winning class per query is picked by choose_class().
 *
 * @param data optional test features; when given they become the rhs of
 *        the distance via init_distance()
 * @return newly allocated multiclass labels for all test examples
 */
CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)
{
    if (data)
        init_distance(data);

    // redirecting to fast (without sorting) classify if k==1
    if (m_k == 1)
        return classify_NN();

    ASSERT(m_num_classes>0);
    ASSERT(distance);
    ASSERT(distance->get_num_vec_rhs());

    int32_t num_lab=distance->get_num_vec_rhs();
    // cannot ask for more neighbors than there are training vectors
    ASSERT(m_k<=distance->get_num_vec_lhs());

    CMulticlassLabels* output=new CMulticlassLabels(num_lab);

    float64_t* dists   = NULL;
    int32_t* train_lab = NULL;

    //distances to train data and working buffer of m_train_labels
    if ( ! m_use_covertree )
    {
        dists=SG_MALLOC(float64_t, m_train_labels.vlen);
        train_lab=SG_MALLOC(int32_t, m_train_labels.vlen);
    }
    else
    {
        // the cover tree path only buffers the k neighbor labels
        train_lab=SG_MALLOC(int32_t, m_k);
    }

    SG_INFO( "%d test examples\n", num_lab);
    CSignal::clear_cancel();

    // per-class vote accumulator handed to choose_class()
    float64_t* classes=SG_MALLOC(float64_t, m_num_classes);

#ifdef BENCHMARK_KNN
    CTime tstart;
    float64_t tfinish, tparsed, tcreated, tqueried;
#endif

    if ( ! m_use_covertree )
    {
        for (int32_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
        {
            SG_PROGRESS(i, 0, num_lab);

#ifdef DEBUG_KNN
            distances_lhs(dists,0,m_train_labels.vlen-1,i);

            for (int32_t j=0; j<m_train_labels.vlen; j++)
                train_lab[j]=j;

            CMath::qsort_index(dists, train_lab, m_train_labels.vlen);

            SG_PRINT("\nQuick sort query %d\n", i);
            for (int32_t j=0; j<m_k; j++)
                SG_PRINT("%d ", train_lab[j]);
            SG_PRINT("\n");
#endif

            //lhs idx 1..n and rhs idx i
            distances_lhs(dists,0,m_train_labels.vlen-1,i);

            for (int32_t j=0; j<m_train_labels.vlen; j++)
                train_lab[j]=m_train_labels.vector[j];

            //sort the distance vector for test example j to all 
            //train examples
            CMath::qsort_index(dists, train_lab, m_train_labels.vlen);

            // Get the index of the 'nearest' class
            int32_t out_idx = choose_class(classes, train_lab);
            // undo the label shift applied in train_machine()
            output->set_label(i, out_idx + m_min_label);
        }

#ifdef BENCHMARK_KNN
        SG_PRINT(">>>> Quick sort applied in %9.4f\n", 
                (tfinish = tstart.cur_time_diff(false)));
#endif
    }
    else    // Use cover tree
    {
        // m_q != 1.0 not supported with cover tree because the neighbors
        // are not retrieved in increasing order of distance to the query
        float64_t old_q = m_q;
        if ( old_q != 1.0 )
            SG_INFO("q != 1.0 not supported with cover tree, using q = 1\n");

        // From the sets of features (lhs and rhs) stored in distance,
        // build arrays of cover tree points
        v_array< CJLCoverTreePoint > set_of_points  = 
            parse_points(distance, FC_LHS);
        v_array< CJLCoverTreePoint > set_of_queries = 
            parse_points(distance, FC_RHS);

#ifdef BENCHMARK_KNN
        SG_PRINT(">>>> JL parsed in %9.4f\n",
            ( tparsed = tstart.cur_time_diff(false) ) - tfinish);
#endif
        // Build the cover trees, one for the test vectors (rhs features) 
        // and another for the training vectors (lhs features)
        // NOTE(review): lhs/rhs are swapped around here so each batch_create
        // sees the right feature set; the original lhs is restored below
        // before querying — confirm against CDistance::replace_* semantics.
        CFeatures* r = distance->replace_rhs( distance->get_lhs() );
        node< CJLCoverTreePoint > top = batch_create(set_of_points);
        CFeatures* l = distance->replace_lhs(r);
        distance->replace_rhs(r);
        node< CJLCoverTreePoint > top_query = batch_create(set_of_queries);

#ifdef BENCHMARK_KNN
        SG_PRINT(">>>> Cover trees created in %9.4f\n", 
                (tcreated = tstart.cur_time_diff(false)) - tparsed);
#endif

        // Get the k nearest neighbors to all the test vectors (batch method)
        distance->replace_lhs(l);
        v_array< v_array< CJLCoverTreePoint > > res;
        k_nearest_neighbor(top, top_query, res, m_k);

#ifdef BENCHMARK_KNN
        SG_PRINT(">>>> Query finished in %9.4f\n", 
                (tqueried = tstart.cur_time_diff(false)) - tcreated);
#endif

#ifdef DEBUG_KNN
        SG_PRINT("\nJL Results:\n");
        for ( int32_t i = 0 ; i < res.index ; ++i )
        {
            for ( int32_t j = 0 ; j < res[i].index ; ++j )
            {
                printf("%d ", res[i][j].m_index);
            }
            printf("\n");
        }
        SG_PRINT("\n");
#endif

        for ( int32_t i = 0 ; i < res.index ; ++i )
        {
            // Translate from indices to labels of the nearest neighbors
            for ( int32_t j = 0; j < m_k; ++j )
                // The first index in res[i] points to the test vector
                train_lab[j] = m_train_labels.vector[ res[i][j+1].m_index ];

            // Get the index of the 'nearest' class
            int32_t out_idx = choose_class(classes, train_lab);
            output->set_label(res[i][0].m_index, out_idx+m_min_label);
        }

        m_q = old_q;

#ifdef BENCHMARK_KNN
        SG_PRINT(">>>> JL applied in %9.4f\n", tstart.cur_time_diff(false));
#endif
    }

    SG_FREE(classes);
    SG_FREE(train_lab);
    if ( ! m_use_covertree )
        SG_FREE(dists);

    return output;
}
00274 
00275 CMulticlassLabels* CKNN::classify_NN()
00276 {
00277     ASSERT(distance);
00278     ASSERT(m_num_classes>0);
00279 
00280     int32_t num_lab = distance->get_num_vec_rhs();
00281     ASSERT(num_lab);
00282 
00283     CMulticlassLabels* output = new CMulticlassLabels(num_lab);
00284     float64_t* distances = SG_MALLOC(float64_t, m_train_labels.vlen);
00285 
00286     SG_INFO("%d test examples\n", num_lab);
00287     CSignal::clear_cancel();
00288 
00289     // for each test example
00290     for (int32_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
00291     {
00292         SG_PROGRESS(i,0,num_lab);
00293 
00294         // get distances from i-th test example to 0..num_m_train_labels-1 train examples
00295         distances_lhs(distances,0,m_train_labels.vlen-1,i);
00296         int32_t j;
00297 
00298         // assuming 0th train examples as nearest to i-th test example
00299         int32_t out_idx = 0;
00300         float64_t min_dist = distances[0];
00301 
00302         // searching for nearest neighbor by comparing distances
00303         for (j=0; j<m_train_labels.vlen; j++)
00304         {
00305             if (distances[j]<min_dist)
00306             {
00307                 min_dist = distances[j];
00308                 out_idx = j;
00309             }
00310         }
00311 
00312         // label i-th test example with label of nearest neighbor with out_idx index
00313         output->set_label(i,m_train_labels.vector[out_idx]+m_min_label);
00314     }
00315 
00316     SG_FREE(distances);
00317     return output;
00318 }
00319 
00320 SGMatrix<int32_t> CKNN::classify_for_multiple_k()
00321 {
00322     ASSERT(m_num_classes>0);
00323     ASSERT(distance);
00324     ASSERT(distance->get_num_vec_rhs());
00325 
00326     int32_t num_lab=distance->get_num_vec_rhs();
00327     ASSERT(m_k<=num_lab);
00328 
00329     int32_t* output=SG_MALLOC(int32_t, m_k*num_lab);
00330 
00331     float64_t* dists;
00332     int32_t* train_lab;
00333     //distances to train data and working buffer of m_train_labels
00334     if ( ! m_use_covertree )
00335     {
00336         dists=SG_MALLOC(float64_t, m_train_labels.vlen);
00337         train_lab=SG_MALLOC(int32_t, m_train_labels.vlen);
00338     }
00339     else
00340     {
00341         dists=SG_MALLOC(float64_t, m_k);
00342         train_lab=SG_MALLOC(int32_t, m_k);
00343     }
00344 
00346     int32_t* classes=SG_MALLOC(int32_t, m_num_classes);
00347     
00348     SG_INFO( "%d test examples\n", num_lab);
00349     CSignal::clear_cancel();
00350 
00351     if ( ! m_use_covertree )
00352     {
00353         for (int32_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
00354         {
00355             SG_PROGRESS(i, 0, num_lab);
00356 
00357             // lhs idx 1..n and rhs idx i
00358             distances_lhs(dists,0,m_train_labels.vlen-1,i);
00359             for (int32_t j=0; j<m_train_labels.vlen; j++)
00360                 train_lab[j]=m_train_labels.vector[j];
00361 
00362             //sort the distance vector for test example j to all train examples
00363             //classes[1..k] then holds the classes for minimum distance
00364             CMath::qsort_index(dists, train_lab, m_train_labels.vlen);
00365 
00366             //compute histogram of class outputs of the first k nearest 
00367             //neighbours
00368             for (int32_t j=0; j<m_num_classes; j++)
00369                 classes[j]=0;
00370 
00371             for (int32_t j=0; j<m_k; j++)
00372             {
00373                 classes[train_lab[j]]++;
00374 
00375                 //choose the class that got 'outputted' most often
00376                 int32_t out_idx=0;
00377                 int32_t out_max=0;
00378 
00379                 for (int32_t c=0; c<m_num_classes; c++)
00380                 {
00381                     if (out_max< classes[c])
00382                     {
00383                         out_idx= c;
00384                         out_max= classes[c];
00385                     }
00386                 }
00387                 output[j*num_lab+i]=out_idx+m_min_label;
00388             }
00389         }
00390     }
00391     else
00392     {
00393         // From the sets of features (lhs and rhs) stored in distance,
00394         // build arrays of cover tree points
00395         v_array< CJLCoverTreePoint > set_of_points  = 
00396             parse_points(distance, FC_LHS);
00397         v_array< CJLCoverTreePoint > set_of_queries = 
00398             parse_points(distance, FC_RHS);
00399         
00400         // Build the cover trees, one for the test vectors (rhs features) 
00401         // and another for the training vectors (lhs features)
00402         CFeatures* r = distance->replace_rhs( distance->get_lhs() );
00403         node< CJLCoverTreePoint > top = batch_create(set_of_points);
00404         CFeatures* l = distance->replace_lhs(r);
00405         distance->replace_rhs(r);
00406         node< CJLCoverTreePoint > top_query = batch_create(set_of_queries);
00407 
00408         // Get the k nearest neighbors to all the test vectors (batch method)
00409         distance->replace_lhs(l);
00410         v_array< v_array< CJLCoverTreePoint > > res;
00411         k_nearest_neighbor(top, top_query, res, m_k);
00412 
00413         for ( int32_t i = 0 ; i < res.index ; ++i )
00414         {
00415             // Handle the fact that cover tree doesn't return neighbors
00416             // ordered by distance
00417             
00418             for ( int32_t j = 0 ; j < m_k ; ++j )
00419             {
00420                 // The first index in res[i] points to the test vector
00421                 dists[j]     = distance->distance(res[i][j+1].m_index,
00422                             res[i][0].m_index);
00423                 train_lab[j] = m_train_labels.vector[ 
00424                             res[i][j+1].m_index ];
00425             }
00426 
00427             // Now we get the indices to the neighbors sorted by distance
00428             CMath::qsort_index(dists, train_lab, m_k);
00429 
00430             //compute histogram of class outputs of the first k nearest 
00431             //neighbours
00432             for (int32_t j=0; j<m_num_classes; j++)
00433                 classes[j]=0;
00434 
00435             for (int32_t j=0; j<m_k; j++)
00436             {
00437                 classes[train_lab[j]]++;
00438 
00439                 //choose the class that got 'outputted' most often
00440                 int32_t out_idx=0;
00441                 int32_t out_max=0;
00442 
00443                 for (int32_t c=0; c<m_num_classes; c++)
00444                 {
00445                     if (out_max< classes[c])
00446                     {
00447                         out_idx= c;
00448                         out_max= classes[c];
00449                     }
00450                 }
00451                 output[j*num_lab+res[i][0].m_index]=out_idx+m_min_label;
00452             }
00453 
00454         }
00455     }
00456 
00457     SG_FREE(train_lab);
00458     SG_FREE(classes);
00459     SG_FREE(dists);
00460 
00461     return SGMatrix<int32_t>(output,num_lab,m_k,true);
00462 }
00463 
00464 void CKNN::init_distance(CFeatures* data)
00465 {
00466     if (!distance)
00467         SG_ERROR("No distance assigned!\n");
00468     CFeatures* lhs=distance->get_lhs();
00469     if (!lhs || !lhs->get_num_vectors())
00470     {
00471         SG_UNREF(lhs);
00472         SG_ERROR("No vectors on left hand side\n");
00473     }
00474     distance->init(lhs, data);
00475     SG_UNREF(lhs);
00476 }
00477 
/** Deserialization stub — loading a KNN model from a raw FILE* is not
 * implemented.
 *
 * @param srcfile ignored
 * @return always false
 */
bool CKNN::load(FILE* srcfile)
{
    SG_SET_LOCALE_C;
    SG_RESET_LOCALE;
    return false;
}
00484 
/** Serialization stub — saving a KNN model to a raw FILE* is not
 * implemented.
 *
 * @param dstfile ignored
 * @return always false
 */
bool CKNN::save(FILE* dstfile)
{
    SG_SET_LOCALE_C;
    SG_RESET_LOCALE;
    return false;
}
00491 
/** Replaces the training features (lhs) of the distance with a duplicate
 * so the model no longer aliases caller-owned feature objects.
 */
void CKNN::store_model_features()
{
    CFeatures* d_lhs=distance->get_lhs();
    CFeatures* d_rhs=distance->get_rhs();

    /* copy lhs of underlying distance */
    distance->init(d_lhs->duplicate(), d_rhs);

    // drop the references acquired by get_lhs()/get_rhs(); the distance
    // presumably holds its own references after init() — per SG_REF/SG_UNREF
    // convention
    SG_UNREF(d_lhs);
    SG_UNREF(d_rhs);
}
00503 
00504 int32_t CKNN::choose_class(float64_t* classes, int32_t* train_lab)
00505 {
00506     memset(classes, 0, sizeof(float64_t)*m_num_classes);
00507 
00508     float64_t multiplier = m_q;
00509     for (int32_t j=0; j<m_k; j++)
00510     {
00511         classes[train_lab[j]]+= multiplier;
00512         multiplier*= multiplier;
00513     }
00514 
00515     //choose the class that got 'outputted' most often
00516     int32_t out_idx=0;
00517     float64_t out_max=0;
00518 
00519     for (int32_t j=0; j<m_num_classes; j++)
00520     {
00521         if (out_max< classes[j])
00522         {
00523             out_idx= j;
00524             out_max= classes[j];
00525         }
00526     }
00527 
00528     return out_idx;
00529 }
All Classes | Namespaces | Files | Functions | Variables | Typedefs | Enumerations | Enumerator | Friends | Defines

SHOGUN Machine Learning Toolbox — Documentation