Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/mathematics/Statistics.h>
00017
00018 using namespace shogun;
00019
00020 CCrossValidation::CCrossValidation()
00021 {
00022 init();
00023 }
00024
00025 CCrossValidation::~CCrossValidation()
00026 {
00027 SG_UNREF(m_machine);
00028 SG_UNREF(m_features);
00029 SG_UNREF(m_labels);
00030 SG_UNREF(m_splitting_strategy);
00031 SG_UNREF(m_evaluation_criterium);
00032 }
00033
00034 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00035 CLabels* labels, CSplittingStrategy* splitting_strategy,
00036 CEvaluation* evaluation_criterium)
00037 {
00038 init();
00039
00040 m_machine=machine;
00041 m_features=features;
00042 m_labels=labels;
00043 m_splitting_strategy=splitting_strategy;
00044 m_evaluation_criterium=evaluation_criterium;
00045
00046 SG_REF(m_machine);
00047 SG_REF(m_features);
00048 SG_REF(m_labels);
00049 SG_REF(m_splitting_strategy);
00050 SG_REF(m_evaluation_criterium);
00051 }
00052
00053 void CCrossValidation::init()
00054 {
00055 m_machine=NULL;
00056 m_features=NULL;
00057 m_labels=NULL;
00058 m_splitting_strategy=NULL;
00059 m_evaluation_criterium=NULL;
00060 m_num_runs=1;
00061 m_conf_int_alpha=0;
00062
00063 m_parameters->add((CSGObject**) &m_machine, "machine",
00064 "Used learning machine");
00065 m_parameters->add((CSGObject**) &m_features, "features", "Used features");
00066 m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels");
00067 m_parameters->add((CSGObject**) &m_splitting_strategy,
00068 "splitting_strategy", "Used splitting strategy");
00069 m_parameters->add((CSGObject**) &m_evaluation_criterium,
00070 "evaluation_criterium", "Used evaluation criterium");
00071 m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions");
00072 m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence "
00073 "interval");
00074 }
00075
00076 CMachine* CCrossValidation::get_machine() const
00077 {
00078 SG_REF(m_machine);
00079 return m_machine;
00080 }
00081
00082 CrossValidationResult CCrossValidation::evaluate()
00083 {
00084 SGVector<float64_t> results(m_num_runs);
00085
00086 for (index_t i=0; i<m_num_runs; ++i)
00087 results.vector[i]=evaluate_one_run();
00088
00089
00090 CrossValidationResult result;
00091 result.has_conf_int=m_conf_int_alpha!=0;
00092 result.conf_int_alpha=m_conf_int_alpha;
00093
00094 if (result.has_conf_int)
00095 {
00096 result.conf_int_alpha=m_conf_int_alpha;
00097 result.mean=CStatistics::confidence_intervals_mean(results,
00098 result.conf_int_alpha, result.conf_int_low, result.conf_int_up);
00099 }
00100 else
00101 {
00102 result.mean=CStatistics::mean(results);
00103 result.conf_int_low=0;
00104 result.conf_int_up=0;
00105 }
00106
00107 SG_FREE(results.vector);
00108
00109 return result;
00110 }
00111
00112 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00113 {
00114 if (conf_int_alpha<0 || conf_int_alpha>=1)
00115 {
00116 SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00117 "cross-validation\n", conf_int_alpha);
00118 }
00119
00120 m_conf_int_alpha=conf_int_alpha;
00121 }
00122
00123 void CCrossValidation::set_num_runs(int32_t num_runs)
00124 {
00125 if (num_runs<1)
00126 SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00127
00128 m_num_runs=num_runs;
00129 }
00130
00131 float64_t CCrossValidation::evaluate_one_run()
00132 {
00133 index_t num_subsets=m_splitting_strategy->get_num_subsets();
00134 float64_t* results=SG_MALLOC(float64_t, num_subsets);
00135
00136
00137 m_machine->set_labels(m_labels);
00138
00139
00140
00141 m_machine->set_store_model_features(true);
00142
00143
00144 for (index_t i=0; i<num_subsets; ++i)
00145 {
00146
00147 SGVector<index_t> inverse_subset_indices=
00148 m_splitting_strategy->generate_subset_inverse(i);
00149 m_features->set_subset(new CSubset(inverse_subset_indices));
00150
00151
00152 SGVector<index_t> inverse_subset_indices_copy(
00153 inverse_subset_indices.vlen);
00154 memcpy(inverse_subset_indices_copy.vector,
00155 inverse_subset_indices.vector,
00156 inverse_subset_indices.vlen*sizeof(index_t));
00157 m_labels->set_subset(new CSubset(inverse_subset_indices_copy));
00158
00159
00160 m_machine->train(m_features);
00161
00162
00163 SGVector<index_t> subset_indices=
00164 m_splitting_strategy->generate_subset_indices(i);
00165 m_features->set_subset(new CSubset(subset_indices));
00166
00167
00168 CLabels* result_labels=m_machine->apply(m_features);
00169 SG_REF(result_labels);
00170
00171
00172 SGVector<index_t> subset_indices_copy(subset_indices.vlen);
00173 memcpy(subset_indices_copy.vector, subset_indices.vector,
00174 subset_indices.vlen*sizeof(index_t));
00175 m_labels->set_subset(new CSubset(subset_indices_copy));
00176
00177
00178 results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels);
00179
00180
00181 SG_UNREF(result_labels);
00182 m_features->remove_subset();
00183 m_labels->remove_subset();
00184 }
00185
00186
00187 float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets));
00188
00189
00190 SG_FREE(results);
00191
00192 return mean;
00193 }