Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/mathematics/Statistics.h>
00017
00018 using namespace shogun;
00019
00020 CCrossValidation::CCrossValidation()
00021 {
00022 init();
00023 }
00024
00025 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00026 CLabels* labels, CSplittingStrategy* splitting_strategy,
00027 CEvaluation* evaluation_criterium)
00028 {
00029 init();
00030
00031 m_machine=machine;
00032 m_features=features;
00033 m_labels=labels;
00034 m_splitting_strategy=splitting_strategy;
00035 m_evaluation_criterium=evaluation_criterium;
00036
00037 SG_REF(m_machine);
00038 SG_REF(m_features);
00039 SG_REF(m_labels);
00040 SG_REF(m_splitting_strategy);
00041 SG_REF(m_evaluation_criterium);
00042 }
00043
00044 CCrossValidation::~CCrossValidation()
00045 {
00046 SG_UNREF(m_machine);
00047 SG_UNREF(m_features);
00048 SG_UNREF(m_labels);
00049 SG_UNREF(m_splitting_strategy);
00050 SG_UNREF(m_evaluation_criterium);
00051 }
00052
00053 EEvaluationDirection CCrossValidation::get_evaluation_direction()
00054 {
00055 return m_evaluation_criterium->get_evaluation_direction();
00056 }
00057
00058 void CCrossValidation::init()
00059 {
00060 m_machine=NULL;
00061 m_features=NULL;
00062 m_labels=NULL;
00063 m_splitting_strategy=NULL;
00064 m_evaluation_criterium=NULL;
00065 m_num_runs=1;
00066 m_conf_int_alpha=0;
00067
00068 m_parameters->add((CSGObject**) &m_machine, "machine",
00069 "Used learning machine");
00070 m_parameters->add((CSGObject**) &m_features, "features", "Used features");
00071 m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
00072 m_parameters->add((CSGObject**) &m_splitting_strategy,
00073 "splitting_strategy", "Used splitting strategy");
00074 m_parameters->add((CSGObject**) &m_evaluation_criterium,
00075 "evaluation_criterium", "Used evaluation criterium");
00076 m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
00077 m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
00078 "interval");
00079 }
00080
00081 CMachine* CCrossValidation::get_machine() const
00082 {
00083 SG_REF(m_machine);
00084 return m_machine;
00085 }
00086
00087 CrossValidationResult CCrossValidation::evaluate()
00088 {
00089 SGVector<float64_t> results(m_num_runs);
00090
00091 for (index_t i=0; i<m_num_runs; ++i)
00092 results.vector[i]=evaluate_one_run();
00093
00094
00095 CrossValidationResult result;
00096 result.has_conf_int=m_conf_int_alpha!=0;
00097 result.conf_int_alpha=m_conf_int_alpha;
00098
00099 if (result.has_conf_int)
00100 {
00101 result.conf_int_alpha=m_conf_int_alpha;
00102 result.mean=CStatistics::confidence_intervals_mean(results,
00103 result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
00104 }
00105 else
00106 {
00107 result.mean=CStatistics::mean(results);
00108 result.conf_int_low=0;
00109 result.conf_int_up=0;
00110 }
00111
00112 SG_FREE(results.vector);
00113
00114 return result;
00115 }
00116
00117 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00118 {
00119 if (conf_int_alpha<0 || conf_int_alpha>=1)
00120 {
00121 SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00122 "cross-validation\n", conf_int_alpha);
00123 }
00124
00125 m_conf_int_alpha=conf_int_alpha;
00126 }
00127
00128 void CCrossValidation::set_num_runs(int32_t num_runs)
00129 {
00130 if (num_runs<1)
00131 SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00132
00133 m_num_runs=num_runs;
00134 }
00135
00136 float64_t CCrossValidation::evaluate_one_run()
00137 {
00138 index_t num_subsets=m_splitting_strategy->get_num_subsets();
00139 float64_t* results=SG_MALLOC(float64_t, num_subsets);
00140
00141
00142 m_machine->set_labels(m_labels);
00143
00144
00145
00146 m_machine->set_store_model_features(true);
00147
00148
00149 for (index_t i=0; i<num_subsets; ++i)
00150 {
00151
00152 SGVector<index_t> inverse_subset_indices=
00153 m_splitting_strategy->generate_subset_inverse(i);
00154 m_features->set_subset(new CSubset(inverse_subset_indices));
00155
00156
00157 SGVector<index_t> inverse_subset_indices_copy(
00158 inverse_subset_indices.vlen);
00159 memcpy(inverse_subset_indices_copy.vector,
00160 inverse_subset_indices.vector,
00161 inverse_subset_indices.vlen*sizeof(index_t));
00162 m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
00163
00164
00165 m_machine->train(m_features);
00166
00167
00168 SGVector<index_t> subset_indices=
00169 m_splitting_strategy->generate_subset_indices(i);
00170 m_features->set_subset(new CSubset(subset_indices));
00171
00172
00173 CLabels* result_labels=m_machine->apply(m_features);
00174 SG_REF(result_labels);
00175
00176
00177 SGVector<index_t> subset_indices_copy(subset_indices.vlen);
00178 memcpy(subset_indices_copy.vector, subset_indices.vector,
00179 subset_indices.vlen*sizeof(index_t));
00180 m_labels->set_subset(new CSubset(subset_indices_copy));
00181
00182
00183 results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
00184
00185
00186 SG_UNREF(result_labels);
00187 m_features->remove_subset();
00188 m_labels->remove_subset();
00189 }
00190
00191
00192 float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
00193
00194
00195 SG_FREE(results);
00196
00197 return mean;
00198 }