CrossValidation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/mathematics/Statistics.h>
00017 
00018 using namespace shogun;
00019 
00020 CCrossValidation::CCrossValidation()
00021 {
00022     init();
00023 }
00024 
00025 CCrossValidation::~CCrossValidation()
00026 {
00027     SG_UNREF(m_machine);
00028     SG_UNREF(m_features);
00029     SG_UNREF(m_labels);
00030     SG_UNREF(m_splitting_strategy);
00031     SG_UNREF(m_evaluation_criterium);
00032 }
00033 
00034 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00035         CLabels* labels, CSplittingStrategy* splitting_strategy,
00036         CEvaluation* evaluation_criterium)
00037 {
00038     init();
00039 
00040     m_machine=machine;
00041     m_features=features;
00042     m_labels=labels;
00043     m_splitting_strategy=splitting_strategy;
00044     m_evaluation_criterium=evaluation_criterium;
00045 
00046     SG_REF(m_machine);
00047     SG_REF(m_features);
00048     SG_REF(m_labels);
00049     SG_REF(m_splitting_strategy);
00050     SG_REF(m_evaluation_criterium);
00051 }
00052 
00053 void CCrossValidation::init()
00054 {
00055     m_machine=NULL;
00056     m_features=NULL;
00057     m_labels=NULL;
00058     m_splitting_strategy=NULL;
00059     m_evaluation_criterium=NULL;
00060     m_num_runs=1;
00061     m_conf_int_alpha=0;
00062 
00063     m_parameters->add((CSGObject**) &m_machine, "machine",
00064             "Used learning machine");
00065     m_parameters->add((CSGObject**) &m_features, "features", "Used features");
00066     m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
00067     m_parameters->add((CSGObject**) &m_splitting_strategy,
00068             "splitting_strategy", "Used splitting strategy");
00069     m_parameters->add((CSGObject**) &m_evaluation_criterium,
00070             "evaluation_criterium", "Used evaluation criterium");
00071     m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
00072     m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
00073             "interval");
00074 }
00075 
00076 CMachine* CCrossValidation::get_machine() const
00077 {
00078     SG_REF(m_machine);
00079     return m_machine;
00080 }
00081 
00082 CrossValidationResult CCrossValidation::evaluate()
00083 {
00084     SGVector<float64_t> results(m_num_runs);
00085 
00086     for (index_t i=0; i<m_num_runs; ++i)
00087         results.vector[i]=evaluate_one_run();
00088 
00089     /* construct evaluation result */
00090     CrossValidationResult result;
00091     result.has_conf_int=m_conf_int_alpha!=0;
00092     result.conf_int_alpha=m_conf_int_alpha;
00093 
00094     if (result.has_conf_int)
00095     {
00096         result.conf_int_alpha=m_conf_int_alpha;
00097         result.mean=CStatistics::confidence_intervals_mean(results,
00098                 result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
00099     }
00100     else
00101     {
00102         result.mean=CStatistics::mean(results);
00103         result.conf_int_low=0;
00104         result.conf_int_up=0;
00105     }
00106 
00107     SG_FREE(results.vector);
00108 
00109     return result;
00110 }
00111 
00112 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00113 {
00114     if (conf_int_alpha<0 || conf_int_alpha>=1)
00115     {
00116         SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00117                 "cross-validation\n", conf_int_alpha);
00118     }
00119 
00120     m_conf_int_alpha=conf_int_alpha;
00121 }
00122 
00123 void CCrossValidation::set_num_runs(int32_t num_runs)
00124 {
00125     if (num_runs<1)
00126         SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00127 
00128     m_num_runs=num_runs;
00129 }
00130 
00131 float64_t CCrossValidation::evaluate_one_run()
00132 {
00133     index_t num_subsets=m_splitting_strategy->get_num_subsets();
00134     float64_t* results=SG_MALLOC(float64_t, num_subsets);
00135 
00136     /* set labels to machine */
00137     m_machine->set_labels(m_labels);
00138 
00139     /* tell machine to store model internally
00140      * (otherwise changing subset of features will kaboom the classifier) */
00141     m_machine->set_store_model_features(true);
00142 
00143     /* do actual cross-validation */
00144     for (index_t i=0; i<num_subsets; ++i)
00145     {
00146         /* set feature subset for training */
00147         SGVector<index_t> inverse_subset_indices=
00148                 m_splitting_strategy->generate_subset_inverse(i);
00149         m_features->set_subset(new CSubset(inverse_subset_indices));
00150 
00151         /* set label subset for training (copy data before) */
00152         SGVector<index_t> inverse_subset_indices_copy(
00153                 inverse_subset_indices.vlen);
00154         memcpy(inverse_subset_indices_copy.vector,
00155                 inverse_subset_indices.vector,
00156                 inverse_subset_indices.vlen*sizeof(index_t));
00157         m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
00158 
00159         /* train machine on training features */
00160         m_machine->train(m_features);
00161 
00162         /* set feature subset for testing (subset method that stores pointer) */
00163         SGVector<index_t> subset_indices=
00164                 m_splitting_strategy->generate_subset_indices(i);
00165         m_features->set_subset(new CSubset(subset_indices));
00166 
00167         /* apply machine to test features */
00168         CLabels* result_labels=m_machine->apply(m_features);
00169         SG_REF(result_labels);
00170 
00171         /* set label subset for testing (copy data before) */
00172         SGVector<index_t> subset_indices_copy(subset_indices.vlen);
00173         memcpy(subset_indices_copy.vector, subset_indices.vector,
00174                 subset_indices.vlen*sizeof(index_t));
00175         m_labels->set_subset(new CSubset(subset_indices_copy));
00176 
00177         /* evaluate */
00178         results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
00179 
00180         /* clean up, reset subsets */
00181         SG_UNREF(result_labels);
00182         m_features->remove_subset();
00183         m_labels->remove_subset();
00184     }
00185 
00186     /* build arithmetic mean of results */
00187     float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
00188 
00189     /* clean up */
00190     SG_FREE(results);
00191 
00192     return mean;
00193 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation