00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/CrossValidation.h>
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/evaluation/Evaluation.h>
00014 #include <shogun/evaluation/SplittingStrategy.h>
00015 #include <shogun/base/Parameter.h>
00016 #include <shogun/base/ParameterMap.h>
00017 #include <shogun/mathematics/Statistics.h>
00018 #include <shogun/evaluation/CrossValidationOutput.h>
00019 #include <shogun/lib/List.h>
00020
00021 using namespace shogun;
00022
00023 CCrossValidation::CCrossValidation()
00024 {
00025 init();
00026 }
00027
00028 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
00029 CLabels* labels, CSplittingStrategy* splitting_strategy,
00030 CEvaluation* evaluation_criterion, bool autolock) :
00031 CMachineEvaluation(machine, features, labels, splitting_strategy,
00032 evaluation_criterion, autolock)
00033 {
00034 init();
00035 }
00036
00037 CCrossValidation::CCrossValidation(CMachine* machine, CLabels* labels,
00038 CSplittingStrategy* splitting_strategy,
00039 CEvaluation* evaluation_criterion, bool autolock) :
00040 CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
00041 autolock)
00042 {
00043 init();
00044 }
00045
00046 CCrossValidation::~CCrossValidation()
00047 {
00048 SG_UNREF(m_xval_outputs);
00049 }
00050
00051 void CCrossValidation::init()
00052 {
00053 m_num_runs=1;
00054 m_conf_int_alpha=0;
00055
00056
00057 m_xval_outputs=new CList(true);
00058
00059 SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
00060 MS_NOT_AVAILABLE);
00061 SG_ADD(&m_conf_int_alpha, "conf_int_alpha", "alpha-value "
00062 "of confidence interval", MS_NOT_AVAILABLE);
00063 SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
00064 "classes for intermediade cross-validation results",
00065 MS_NOT_AVAILABLE);
00066 }
00067
00068 CEvaluationResult* CCrossValidation::evaluate()
00069 {
00070 SG_DEBUG("entering %s::evaluate()\n", get_name());
00071
00072 REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
00073 "attached\n", get_name());
00074
00075 REQUIRE(m_features, "%s::evaluate() is only possible if features are "
00076 "attached\n", get_name());
00077
00078 REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
00079 "attached\n", get_name());
00080
00081
00082 if (m_do_unlock)
00083 {
00084 m_machine->data_unlock();
00085 m_do_unlock=false;
00086 }
00087
00088
00089 m_machine->set_labels(m_labels);
00090
00091 if (m_autolock)
00092 {
00093
00094 if (m_machine->supports_locking())
00095 {
00096
00097 if (!m_machine->is_data_locked())
00098 {
00099 m_machine->data_lock(m_labels, m_features);
00100 m_do_unlock=true;
00101 }
00102 }
00103 else
00104 {
00105 SG_WARNING("%s does not support locking. Autolocking is skipped. "
00106 "Set autolock flag to false to get rid of warning.\n",
00107 m_machine->get_name());
00108 }
00109 }
00110
00111 SGVector<float64_t> results(m_num_runs);
00112
00113
00114 CCrossValidationOutput* current=(CCrossValidationOutput*)
00115 m_xval_outputs->get_first_element();
00116 while (current)
00117 {
00118 current->init_num_runs(m_num_runs);
00119 current->init_num_folds(m_splitting_strategy->get_num_subsets());
00120 current->init_expose_labels(m_labels);
00121 current->post_init();
00122 SG_UNREF(current);
00123 current=(CCrossValidationOutput*)
00124 m_xval_outputs->get_next_element();
00125 }
00126
00127
00128 SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs);
00129 for (index_t i=0; i <m_num_runs; ++i)
00130 {
00131
00132
00133 current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00134 while (current)
00135 {
00136 current->update_run_index(i);
00137 SG_UNREF(current);
00138 current=(CCrossValidationOutput*)
00139 m_xval_outputs->get_next_element();
00140 }
00141
00142 SG_DEBUG("entering cross-validation run %d \n", i);
00143 results[i]=evaluate_one_run();
00144 SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i]);
00145 }
00146
00147
00148 CCrossValidationResult* result = new CCrossValidationResult();
00149 result->has_conf_int=m_conf_int_alpha != 0;
00150 result->conf_int_alpha=m_conf_int_alpha;
00151
00152 if (result->has_conf_int)
00153 {
00154 result->conf_int_alpha=m_conf_int_alpha;
00155 result->mean=CStatistics::confidence_intervals_mean(results,
00156 result->conf_int_alpha, result->conf_int_low, result->conf_int_up);
00157 }
00158 else
00159 {
00160 result->mean=CStatistics::mean(results);
00161 result->conf_int_low=0;
00162 result->conf_int_up=0;
00163 }
00164
00165
00166 if (m_machine->is_data_locked() && m_do_unlock)
00167 {
00168 m_machine->data_unlock();
00169 m_do_unlock=false;
00170 }
00171
00172 SG_DEBUG("leaving %s::evaluate()\n", get_name());
00173
00174 SG_REF(result);
00175 return result;
00176 }
00177
00178 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
00179 {
00180 if (conf_int_alpha <0 || conf_int_alpha>= 1) {
00181 SG_ERROR("%f is an illegal alpha-value for confidence interval of "
00182 "cross-validation\n", conf_int_alpha);
00183 }
00184
00185 if (m_num_runs==1)
00186 {
00187 SG_WARNING("Confidence interval for Cross-Validation only possible"
00188 " when number of runs is >1, ignoring.\n");
00189 }
00190 else
00191 m_conf_int_alpha=conf_int_alpha;
00192 }
00193
00194 void CCrossValidation::set_num_runs(int32_t num_runs)
00195 {
00196 if (num_runs <1)
00197 SG_ERROR("%d is an illegal number of repetitions\n", num_runs);
00198
00199 m_num_runs=num_runs;
00200 }
00201
00202 float64_t CCrossValidation::evaluate_one_run()
00203 {
00204 SG_DEBUG("entering %s::evaluate_one_run()\n", get_name());
00205 index_t num_subsets=m_splitting_strategy->get_num_subsets();
00206
00207 SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets);
00208
00209
00210 m_splitting_strategy->build_subsets();
00211
00212
00213 SGVector<float64_t> results(num_subsets);
00214
00215
00216 if (m_machine->is_data_locked())
00217 {
00218 SG_DEBUG("starting locked evaluation\n", get_name());
00219
00220 for (index_t i=0; i <num_subsets; ++i)
00221 {
00222
00223 CCrossValidationOutput* current=(CCrossValidationOutput*)
00224 m_xval_outputs->get_first_element();
00225 while (current)
00226 {
00227 current->update_fold_index(i);
00228 SG_UNREF(current);
00229 current=(CCrossValidationOutput*)
00230 m_xval_outputs->get_next_element();
00231 }
00232
00233
00234 SGVector<index_t> inverse_subset_indices =
00235 m_splitting_strategy->generate_subset_inverse(i);
00236
00237
00238 m_machine->train_locked(inverse_subset_indices);
00239
00240
00241 SGVector<index_t> subset_indices =
00242 m_splitting_strategy->generate_subset_indices(i);
00243
00244
00245 current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00246 while (current)
00247 {
00248 current->update_train_indices(inverse_subset_indices, "\t");
00249 current->update_trained_machine(m_machine, "\t");
00250 SG_UNREF(current);
00251 current=(CCrossValidationOutput*)
00252 m_xval_outputs->get_next_element();
00253 }
00254
00255
00256 CLabels* result_labels=m_machine->apply_locked(subset_indices);
00257 SG_REF(result_labels);
00258
00259
00260 m_labels->add_subset(subset_indices);
00261
00262
00263 m_evaluation_criterion->set_indices(subset_indices);
00264 results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
00265
00266
00267 current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00268 while (current)
00269 {
00270 current->update_test_indices(subset_indices, "\t");
00271 current->update_test_result(result_labels, "\t");
00272 current->update_test_true_result(m_labels, "\t");
00273 current->post_update_results();
00274 current->update_evaluation_result(results[i], "\t");
00275 SG_UNREF(current);
00276 current=(CCrossValidationOutput*)
00277 m_xval_outputs->get_next_element();
00278 }
00279
00280
00281 m_labels->remove_subset();
00282
00283
00284 SG_UNREF(result_labels);
00285
00286 SG_DEBUG("done locked evaluation\n", get_name());
00287 }
00288 }
00289 else
00290 {
00291 SG_DEBUG("starting unlocked evaluation\n", get_name());
00292
00293
00294 m_machine->set_store_model_features(true);
00295
00296
00297 for (index_t i=0; i <num_subsets; ++i)
00298 {
00299
00300 CCrossValidationOutput* current=(CCrossValidationOutput*)
00301 m_xval_outputs->get_first_element();
00302 while (current)
00303 {
00304 current->update_fold_index(i);
00305 SG_UNREF(current);
00306 current=(CCrossValidationOutput*)
00307 m_xval_outputs->get_next_element();
00308 }
00309
00310
00311 SGVector<index_t> inverse_subset_indices=
00312 m_splitting_strategy->generate_subset_inverse(i);
00313 m_features->add_subset(inverse_subset_indices);
00314 for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
00315 {
00316 CPreprocessor* preprocessor = m_features->get_preprocessor(p);
00317 preprocessor->init(m_features);
00318 SG_UNREF(preprocessor);
00319 }
00320
00321
00322 m_labels->add_subset(inverse_subset_indices);
00323
00324 SG_DEBUG("training set %d:\n", i);
00325 if (io->get_loglevel()==MSG_DEBUG)
00326 {
00327 SGVector<index_t>::display_vector(inverse_subset_indices.vector,
00328 inverse_subset_indices.vlen, "training indices");
00329 }
00330
00331
00332 SG_DEBUG("starting training\n");
00333 m_machine->train(m_features);
00334 SG_DEBUG("finished training\n");
00335
00336
00337 current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00338 while (current)
00339 {
00340 current->update_train_indices(inverse_subset_indices, "\t");
00341 current->update_trained_machine(m_machine, "\t");
00342 SG_UNREF(current);
00343 current=(CCrossValidationOutput*)
00344 m_xval_outputs->get_next_element();
00345 }
00346
00347 m_features->remove_subset();
00348 m_labels->remove_subset();
00349
00350
00351 SGVector<index_t> subset_indices =
00352 m_splitting_strategy->generate_subset_indices(i);
00353 m_features->add_subset(subset_indices);
00354
00355
00356 m_labels->add_subset(subset_indices);
00357
00358 SG_DEBUG("test set %d:\n", i);
00359 if (io->get_loglevel()==MSG_DEBUG)
00360 {
00361 SGVector<index_t>::display_vector(subset_indices.vector,
00362 subset_indices.vlen, "test indices");
00363 }
00364
00365
00366 SG_DEBUG("starting evaluation\n");
00367 SG_DEBUG("%p\n", m_features);
00368 CLabels* result_labels=m_machine->apply(m_features);
00369 SG_DEBUG("finished evaluation\n");
00370 m_features->remove_subset();
00371 SG_REF(result_labels);
00372
00373
00374 results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
00375 SG_DEBUG("result on fold %d is %f\n", i, results[i]);
00376
00377
00378 current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
00379 while (current)
00380 {
00381 current->update_test_indices(subset_indices, "\t");
00382 current->update_test_result(result_labels, "\t");
00383 current->update_test_true_result(m_labels, "\t");
00384 current->post_update_results();
00385 current->update_evaluation_result(results[i], "\t");
00386 SG_UNREF(current);
00387 current=(CCrossValidationOutput*)
00388 m_xval_outputs->get_next_element();
00389 }
00390
00391
00392 SG_UNREF(result_labels);
00393 m_labels->remove_subset();
00394 }
00395
00396 SG_DEBUG("done unlocked evaluation\n", get_name());
00397 }
00398
00399
00400 float64_t mean=CStatistics::mean(results);
00401
00402 SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name());
00403 return mean;
00404 }
00405
00406 void CCrossValidation::add_cross_validation_output(
00407 CCrossValidationOutput* cross_validation_output)
00408 {
00409 m_xval_outputs->append_element(cross_validation_output);
00410 }