CrossValidation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/mathematics/Statistics.h>
00017 
00018 using namespace shogun;
00019 
00020 CCrossValidation::CCrossValidation()
00021 {
00022     init();
00023 }
00024 
00025 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00026         CLabels* labels, CSplittingStrategy* splitting_strategy,
00027         CEvaluation* evaluation_criterium)
00028 {
00029     init();
00030 
00031     m_machine=machine;
00032     m_features=features;
00033     m_labels=labels;
00034     m_splitting_strategy=splitting_strategy;
00035     m_evaluation_criterium=evaluation_criterium;
00036 
00037     SG_REF(m_machine);
00038     SG_REF(m_features);
00039     SG_REF(m_labels);
00040     SG_REF(m_splitting_strategy);
00041     SG_REF(m_evaluation_criterium);
00042 }
00043 
00044 CCrossValidation::~CCrossValidation()
00045 {
00046     SG_UNREF(m_machine);
00047     SG_UNREF(m_features);
00048     SG_UNREF(m_labels);
00049     SG_UNREF(m_splitting_strategy);
00050     SG_UNREF(m_evaluation_criterium);
00051 }
00052 
00053 EEvaluationDirection CCrossValidation::get_evaluation_direction()
00054 {
00055     return m_evaluation_criterium->get_evaluation_direction();
00056 }
00057 
00058 void CCrossValidation::init()
00059 {
00060     m_machine=NULL;
00061     m_features=NULL;
00062     m_labels=NULL;
00063     m_splitting_strategy=NULL;
00064     m_evaluation_criterium=NULL;
00065     m_num_runs=1;
00066     m_conf_int_alpha=0;
00067 
00068     m_parameters->add((CSGObject**) &m_machine, "machine",
00069             "Used learning machine");
00070     m_parameters->add((CSGObject**) &m_features, "features", "Used features");
00071     m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
00072     m_parameters->add((CSGObject**) &m_splitting_strategy,
00073             "splitting_strategy", "Used splitting strategy");
00074     m_parameters->add((CSGObject**) &m_evaluation_criterium,
00075             "evaluation_criterium", "Used evaluation criterium");
00076     m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
00077     m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
00078             "interval");
00079 }
00080 
00081 CMachine* CCrossValidation::get_machine() const
00082 {
00083     SG_REF(m_machine);
00084     return m_machine;
00085 }
00086 
00087 CrossValidationResult CCrossValidation::evaluate()
00088 {
00089     SGVector<float64_t> results(m_num_runs);
00090 
00091     for (index_t i=0; i<m_num_runs; ++i)
00092         results.vector[i]=evaluate_one_run();
00093 
00094     /* construct evaluation result */
00095     CrossValidationResult result;
00096     result.has_conf_int=m_conf_int_alpha!=0;
00097     result.conf_int_alpha=m_conf_int_alpha;
00098 
00099     if (result.has_conf_int)
00100     {
00101         result.conf_int_alpha=m_conf_int_alpha;
00102         result.mean=CStatistics::confidence_intervals_mean(results,
00103                 result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
00104     }
00105     else
00106     {
00107         result.mean=CStatistics::mean(results);
00108         result.conf_int_low=0;
00109         result.conf_int_up=0;
00110     }
00111 
00112     SG_FREE(results.vector);
00113 
00114     return result;
00115 }
00116 
00117 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00118 {
00119     if (conf_int_alpha<0 || conf_int_alpha>=1)
00120     {
00121         SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00122                 "cross-validation\n", conf_int_alpha);
00123     }
00124 
00125     m_conf_int_alpha=conf_int_alpha;
00126 }
00127 
00128 void CCrossValidation::set_num_runs(int32_t num_runs)
00129 {
00130     if (num_runs<1)
00131         SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00132 
00133     m_num_runs=num_runs;
00134 }
00135 
00136 float64_t CCrossValidation::evaluate_one_run()
00137 {
00138     index_t num_subsets=m_splitting_strategy->get_num_subsets();
00139     float64_t* results=SG_MALLOC(float64_t, num_subsets);
00140 
00141     /* set labels to machine */
00142     m_machine->set_labels(m_labels);
00143 
00144     /* tell machine to store model internally
00145      * (otherwise changing subset of features will kaboom the classifier) */
00146     m_machine->set_store_model_features(true);
00147 
00148     /* do actual cross-validation */
00149     for (index_t i=0; i<num_subsets; ++i)
00150     {
00151         /* set feature subset for training */
00152         SGVector<index_t> inverse_subset_indices=
00153                 m_splitting_strategy->generate_subset_inverse(i);
00154         m_features->set_subset(new CSubset(inverse_subset_indices));
00155 
00156         /* set label subset for training (copy data before) */
00157         SGVector<index_t> inverse_subset_indices_copy(
00158                 inverse_subset_indices.vlen);
00159         memcpy(inverse_subset_indices_copy.vector,
00160                 inverse_subset_indices.vector,
00161                 inverse_subset_indices.vlen*sizeof(index_t));
00162         m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
00163 
00164         /* train machine on training features */
00165         m_machine->train(m_features);
00166 
00167         /* set feature subset for testing (subset method that stores pointer) */
00168         SGVector<index_t> subset_indices=
00169                 m_splitting_strategy->generate_subset_indices(i);
00170         m_features->set_subset(new CSubset(subset_indices));
00171 
00172         /* apply machine to test features */
00173         CLabels* result_labels=m_machine->apply(m_features);
00174         SG_REF(result_labels);
00175 
00176         /* set label subset for testing (copy data before) */
00177         SGVector<index_t> subset_indices_copy(subset_indices.vlen);
00178         memcpy(subset_indices_copy.vector, subset_indices.vector,
00179                 subset_indices.vlen*sizeof(index_t));
00180         m_labels->set_subset(new CSubset(subset_indices_copy));
00181 
00182         /* evaluate */
00183         results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
00184 
00185         /* clean up, reset subsets */
00186         SG_UNREF(result_labels);
00187         m_features->remove_subset();
00188         m_labels->remove_subset();
00189     }
00190 
00191     /* build arithmetic mean of results */
00192     float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
00193 
00194     /* clean up */
00195     SG_FREE(results);
00196 
00197     return mean;
00198 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation