CrossValidation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/base/ParameterMap.h>
00017 #include <shogun/mathematics/Statistics.h>
00018 #include <shogun/evaluation/CrossValidationOutput.h>
00019 #include <shogun/lib/List.h>
00020 
00021 using namespace shogun;
00022 
00023 CCrossValidation::CCrossValidation()
00024 {
00025     init();
00026 }
00027 
00028 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00029         CLabels* labels, CSplittingStrategy* splitting_strategy,
00030         CEvaluation* evaluation_criterion, bool autolock) :
00031         CMachineEvaluation(machine, features, labels, splitting_strategy,
00032         evaluation_criterion, autolock)
00033 {
00034     init();
00035 }
00036 
00037 CCrossValidation::CCrossValidation(CMachine* machine, CLabels* labels,
00038         CSplittingStrategy* splitting_strategy,
00039         CEvaluation* evaluation_criterion, bool autolock) :
00040         CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
00041         autolock)
00042 {
00043     init();
00044 }
00045 
00046 CCrossValidation::~CCrossValidation()
00047 {
00048     SG_UNREF(m_xval_outputs);
00049 }
00050 
00051 void CCrossValidation::init()
00052 {
00053     m_num_runs=1;
00054     m_conf_int_alpha=0;
00055 
00056     /* do reference counting for output objects */
00057     m_xval_outputs=new CList(true);
00058 
00059     SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
00060             MS_NOT_AVAILABLE);
00061     SG_ADD(&m_conf_int_alpha, "conf_int_alpha", "alpha-value "
00062             "of confidence interval", MS_NOT_AVAILABLE);
00063     SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
00064             "classes for intermediade cross-validation results",
00065             MS_NOT_AVAILABLE);
00066 }
00067 
00068 CEvaluationResult* CCrossValidation::evaluate()
00069 {
00070     SG_DEBUG("entering %s::evaluate()\n", get_name());
00071 
00072     REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
00073             "attached\n", get_name());
00074 
00075     REQUIRE(m_features, "%s::evaluate() is only possible if features are "
00076             "attached\n", get_name());
00077 
00078     REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
00079             "attached\n", get_name());
00080 
00081     /* if for some reason the do_unlock_frag is set, unlock */
00082     if (m_do_unlock)
00083     {
00084         m_machine->data_unlock();
00085         m_do_unlock=false;
00086     }
00087 
00088     /* set labels in any case (no locking needs this) */
00089     m_machine->set_labels(m_labels);
00090 
00091     if (m_autolock)
00092     {
00093         /* if machine supports locking try to do so */
00094         if (m_machine->supports_locking())
00095         {
00096             /* only lock if machine is not yet locked */
00097             if (!m_machine->is_data_locked())
00098             {
00099                 m_machine->data_lock(m_labels, m_features);
00100                 m_do_unlock=true;
00101             }
00102         }
00103         else
00104         {
00105             SG_WARNING("%s does not support locking. Autolocking is skipped. "
00106                     "Set autolock flag to false to get rid of warning.\n",
00107                     m_machine->get_name());
00108         }
00109     }
00110 
00111     SGVector<float64_t> results(m_num_runs);
00112 
00113     /* evtl. update xvalidation output class */
00114     CCrossValidationOutput* current=(CCrossValidationOutput*)
00115             m_xval_outputs->get_first_element();
00116     while (current)
00117     {
00118         current->init_num_runs(m_num_runs);
00119         current->init_num_folds(m_splitting_strategy->get_num_subsets());
00120         current->init_expose_labels(m_labels);
00121         current->post_init();
00122         SG_UNREF(current);
00123         current=(CCrossValidationOutput*)
00124                 m_xval_outputs->get_next_element();
00125     }
00126 
00127     /* perform all the x-val runs */
00128     SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs);
00129     for (index_t i=0; i <m_num_runs; ++i)
00130     {
00131 
00132         /* evtl. update xvalidation output class */
00133         current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00134         while (current)
00135         {
00136             current->update_run_index(i);
00137             SG_UNREF(current);
00138             current=(CCrossValidationOutput*)
00139                     m_xval_outputs->get_next_element();
00140         }
00141 
00142         SG_DEBUG("entering cross-validation run %d \n", i);
00143         results[i]=evaluate_one_run();
00144         SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i]);
00145     }
00146 
00147     /* construct evaluation result */
00148     CCrossValidationResult* result = new CCrossValidationResult();
00149     result->has_conf_int=m_conf_int_alpha != 0;
00150     result->conf_int_alpha=m_conf_int_alpha;
00151 
00152     if (result->has_conf_int)
00153     {
00154         result->conf_int_alpha=m_conf_int_alpha;
00155         result->mean=CStatistics::confidence_intervals_mean(results,
00156                 result->conf_int_alpha, result->conf_int_low, result->conf_int_up);
00157     }
00158     else
00159     {
00160         result->mean=CStatistics::mean(results);
00161         result->conf_int_low=0;
00162         result->conf_int_up=0;
00163     }
00164 
00165     /* unlock machine if it was locked in this method */
00166     if (m_machine->is_data_locked() && m_do_unlock)
00167     {
00168         m_machine->data_unlock();
00169         m_do_unlock=false;
00170     }
00171 
00172     SG_DEBUG("leaving %s::evaluate()\n", get_name());
00173 
00174     SG_REF(result);
00175     return result;
00176 }
00177 
00178 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00179 {
00180     if (conf_int_alpha <0 || conf_int_alpha>= 1) {
00181         SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00182         "cross-validation\n", conf_int_alpha);
00183     }
00184 
00185     if (m_num_runs==1)
00186     {
00187         SG_WARNING("Confidence interval for Cross-Validation only possible"
00188                 " when number of runs is >1, ignoring.\n");
00189     }
00190     else
00191         m_conf_int_alpha=conf_int_alpha;
00192 }
00193 
00194 void CCrossValidation::set_num_runs(int32_t num_runs)
00195 {
00196     if (num_runs <1)
00197         SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00198 
00199     m_num_runs=num_runs;
00200 }
00201 
00202 float64_t CCrossValidation::evaluate_one_run()
00203 {
00204     SG_DEBUG("entering %s::evaluate_one_run()\n", get_name());
00205     index_t num_subsets=m_splitting_strategy->get_num_subsets();
00206 
00207     SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets);
00208 
00209     /* build index sets */
00210     m_splitting_strategy->build_subsets();
00211 
00212     /* results array */
00213     SGVector<float64_t> results(num_subsets);
00214 
00215     /* different behavior whether data is locked or not */
00216     if (m_machine->is_data_locked())
00217     {
00218         SG_DEBUG("starting locked evaluation\n", get_name());
00219         /* do actual cross-validation */
00220         for (index_t i=0; i <num_subsets; ++i)
00221         {
00222             /* evtl. update xvalidation output class */
00223             CCrossValidationOutput* current=(CCrossValidationOutput*)
00224                     m_xval_outputs->get_first_element();
00225             while (current)
00226             {
00227                 current->update_fold_index(i);
00228                 SG_UNREF(current);
00229                 current=(CCrossValidationOutput*)
00230                         m_xval_outputs->get_next_element();
00231             }
00232 
00233             /* index subset for training, will be freed below */
00234             SGVector<index_t> inverse_subset_indices =
00235                     m_splitting_strategy->generate_subset_inverse(i);
00236 
00237             /* train machine on training features */
00238             m_machine->train_locked(inverse_subset_indices);
00239 
00240             /* feature subset for testing */
00241             SGVector<index_t> subset_indices =
00242                     m_splitting_strategy->generate_subset_indices(i);
00243 
00244             /* evtl. update xvalidation output class */
00245             current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00246             while (current)
00247             {
00248                 current->update_train_indices(inverse_subset_indices, "\t");
00249                 current->update_trained_machine(m_machine, "\t");
00250                 SG_UNREF(current);
00251                 current=(CCrossValidationOutput*)
00252                         m_xval_outputs->get_next_element();
00253             }
00254 
00255             /* produce output for desired indices */
00256             CLabels* result_labels=m_machine->apply_locked(subset_indices);
00257             SG_REF(result_labels);
00258 
00259             /* set subset for testing labels */
00260             m_labels->add_subset(subset_indices);
00261 
00262             /* evaluate against own labels */
00263             m_evaluation_criterion->set_indices(subset_indices);
00264             results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
00265 
00266             /* evtl. update xvalidation output class */
00267             current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00268             while (current)
00269             {
00270                 current->update_test_indices(subset_indices, "\t");
00271                 current->update_test_result(result_labels, "\t");
00272                 current->update_test_true_result(m_labels, "\t");
00273                 current->post_update_results();
00274                 current->update_evaluation_result(results[i], "\t");
00275                 SG_UNREF(current);
00276                 current=(CCrossValidationOutput*)
00277                         m_xval_outputs->get_next_element();
00278             }
00279 
00280             /* remove subset to prevent side effects */
00281             m_labels->remove_subset();
00282 
00283             /* clean up */
00284             SG_UNREF(result_labels);
00285 
00286             SG_DEBUG("done locked evaluation\n", get_name());
00287         }
00288     }
00289     else
00290     {
00291         SG_DEBUG("starting unlocked evaluation\n", get_name());
00292         /* tell machine to store model internally
00293          * (otherwise changing subset of features will kaboom the classifier) */
00294         m_machine->set_store_model_features(true);
00295 
00296         /* do actual cross-validation */
00297         for (index_t i=0; i <num_subsets; ++i)
00298         {
00299             /* evtl. update xvalidation output class */
00300             CCrossValidationOutput* current=(CCrossValidationOutput*)
00301                     m_xval_outputs->get_first_element();
00302             while (current)
00303             {
00304                 current->update_fold_index(i);
00305                 SG_UNREF(current);
00306                 current=(CCrossValidationOutput*)
00307                         m_xval_outputs->get_next_element();
00308             }
00309 
00310             /* set feature subset for training */
00311             SGVector<index_t> inverse_subset_indices=
00312                     m_splitting_strategy->generate_subset_inverse(i);
00313             m_features->add_subset(inverse_subset_indices);
00314             for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
00315             {
00316                 CPreprocessor* preprocessor = m_features->get_preprocessor(p);
00317                 preprocessor->init(m_features);
00318                 SG_UNREF(preprocessor);
00319             }
00320 
00321             /* set label subset for training */
00322             m_labels->add_subset(inverse_subset_indices);
00323 
00324             SG_DEBUG("training set %d:\n", i);
00325             if (io->get_loglevel()==MSG_DEBUG)
00326             {
00327                 SGVector<index_t>::display_vector(inverse_subset_indices.vector,
00328                         inverse_subset_indices.vlen, "training indices");
00329             }
00330 
00331             /* train machine on training features and remove subset */
00332             SG_DEBUG("starting training\n");
00333             m_machine->train(m_features);
00334             SG_DEBUG("finished training\n");
00335 
00336             /* evtl. update xvalidation output class */
00337             current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00338             while (current)
00339             {
00340                 current->update_train_indices(inverse_subset_indices, "\t");
00341                 current->update_trained_machine(m_machine, "\t");
00342                 SG_UNREF(current);
00343                 current=(CCrossValidationOutput*)
00344                         m_xval_outputs->get_next_element();
00345             }
00346 
00347             m_features->remove_subset();
00348             m_labels->remove_subset();
00349 
00350             /* set feature subset for testing (subset method that stores pointer) */
00351             SGVector<index_t> subset_indices =
00352                     m_splitting_strategy->generate_subset_indices(i);
00353             m_features->add_subset(subset_indices);
00354 
00355             /* set label subset for testing */
00356             m_labels->add_subset(subset_indices);
00357 
00358             SG_DEBUG("test set %d:\n", i);
00359             if (io->get_loglevel()==MSG_DEBUG)
00360             {
00361                 SGVector<index_t>::display_vector(subset_indices.vector,
00362                         subset_indices.vlen, "test indices");
00363             }
00364 
00365             /* apply machine to test features and remove subset */
00366             SG_DEBUG("starting evaluation\n");
00367             SG_DEBUG("%p\n", m_features);
00368             CLabels* result_labels=m_machine->apply(m_features);
00369             SG_DEBUG("finished evaluation\n");
00370             m_features->remove_subset();
00371             SG_REF(result_labels);
00372 
00373             /* evaluate */
00374             results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
00375             SG_DEBUG("result on fold %d is %f\n", i, results[i]);
00376 
00377             /* evtl. update xvalidation output class */
00378             current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00379             while (current)
00380             {
00381                 current->update_test_indices(subset_indices, "\t");
00382                 current->update_test_result(result_labels, "\t");
00383                 current->update_test_true_result(m_labels, "\t");
00384                 current->post_update_results();
00385                 current->update_evaluation_result(results[i], "\t");
00386                 SG_UNREF(current);
00387                 current=(CCrossValidationOutput*)
00388                         m_xval_outputs->get_next_element();
00389             }
00390 
00391             /* clean up, remove subsets */
00392             SG_UNREF(result_labels);
00393             m_labels->remove_subset();
00394         }
00395 
00396         SG_DEBUG("done unlocked evaluation\n", get_name());
00397     }
00398 
00399     /* build arithmetic mean of results */
00400     float64_t mean=CStatistics::mean(results);
00401 
00402     SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name());
00403     return mean;
00404 }
00405 
00406 void CCrossValidation::add_cross_validation_output(
00407             CCrossValidationOutput* cross_validation_output)
00408 {
00409     m_xval_outputs->append_element(cross_validation_output);
00410 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation