CrossValidation.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2011-2012 Heiko Strathmann
 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
 */

#include <shogun/evaluation/CrossValidation.h>
#include <shogun/machine/Machine.h>
#include <shogun/evaluation/Evaluation.h>
#include <shogun/evaluation/SplittingStrategy.h>
#include <shogun/base/Parameter.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/evaluation/CrossValidationOutput.h>
#include <shogun/lib/List.h>

using namespace shogun;

CCrossValidation::CCrossValidation() : CMachineEvaluation()
{
    init();
}

CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
        CLabels* labels, CSplittingStrategy* splitting_strategy,
        CEvaluation* evaluation_criterion, bool autolock) :
        CMachineEvaluation(machine, features, labels, splitting_strategy,
                evaluation_criterion, autolock)
{
    init();
}

CCrossValidation::CCrossValidation(CMachine* machine, CLabels* labels,
        CSplittingStrategy* splitting_strategy,
        CEvaluation* evaluation_criterion, bool autolock) :
        CMachineEvaluation(machine, labels, splitting_strategy,
                evaluation_criterion, autolock)
{
    init();
}

CCrossValidation::~CCrossValidation()
{
    SG_UNREF(m_xval_outputs);
}

void CCrossValidation::init()
{
    m_num_runs=1;
    m_conf_int_alpha=0;

    /* do reference counting for output objects */
    m_xval_outputs=new CList(true);

    SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
            MS_NOT_AVAILABLE);
    SG_ADD(&m_conf_int_alpha, "conf_int_alpha", "alpha-value "
            "of confidence interval", MS_NOT_AVAILABLE);
    SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
            "classes for intermediate cross-validation results",
            MS_NOT_AVAILABLE);
}

CEvaluationResult* CCrossValidation::evaluate()
{
    SG_DEBUG("entering %s::evaluate()\n", get_name());

    REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
            "attached\n", get_name());

    REQUIRE(m_features, "%s::evaluate() is only possible if features are "
            "attached\n", get_name());

    REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
            "attached\n", get_name());

    /* if for some reason the do_unlock flag is set, unlock */
    if (m_do_unlock)
    {
        m_machine->data_unlock();
        m_do_unlock=false;
    }

    /* set labels in any case (no locking needs this) */
    m_machine->set_labels(m_labels);

    if (m_autolock)
    {
        /* if machine supports locking try to do so */
        if (m_machine->supports_locking())
        {
            /* only lock if machine is not yet locked */
            if (!m_machine->is_data_locked())
            {
                m_machine->data_lock(m_labels, m_features);
                m_do_unlock=true;
            }
        }
        else
        {
            SG_WARNING("%s does not support locking. Autolocking is skipped. "
                    "Set autolock flag to false to get rid of warning.\n",
                    m_machine->get_name());
        }
    }

    SGVector<float64_t> results(m_num_runs);

    /* possibly update cross-validation output classes */
    CCrossValidationOutput* current=(CCrossValidationOutput*)
            m_xval_outputs->get_first_element();
    while (current)
    {
        current->init_num_runs(m_num_runs);
        current->init_num_folds(m_splitting_strategy->get_num_subsets());
        current->init_expose_labels(m_labels);
        current->post_init();
        SG_UNREF(current);
        current=(CCrossValidationOutput*)
                m_xval_outputs->get_next_element();
    }

    /* perform all the x-val runs */
    SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs);
    for (index_t i=0; i<m_num_runs; ++i)
    {
        /* possibly update cross-validation output classes */
        current=(CCrossValidationOutput*)
                m_xval_outputs->get_first_element();
        while (current)
        {
            current->update_run_index(i);
            SG_UNREF(current);
            current=(CCrossValidationOutput*)
                    m_xval_outputs->get_next_element();
        }

        SG_DEBUG("entering cross-validation run %d\n", i);
        results[i]=evaluate_one_run();
        SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i]);
    }

    /* construct evaluation result */
    CCrossValidationResult* result=new CCrossValidationResult();
    result->has_conf_int=m_conf_int_alpha != 0;
    result->conf_int_alpha=m_conf_int_alpha;

    if (result->has_conf_int)
    {
        result->mean=CStatistics::confidence_intervals_mean(results,
                result->conf_int_alpha, result->conf_int_low, result->conf_int_up);
    }
    else
    {
        result->mean=CStatistics::mean(results);
        result->conf_int_low=0;
        result->conf_int_up=0;
    }

    /* unlock machine if it was locked in this method */
    if (m_machine->is_data_locked() && m_do_unlock)
    {
        m_machine->data_unlock();
        m_do_unlock=false;
    }

    SG_DEBUG("leaving %s::evaluate()\n", get_name());

    SG_REF(result);
    return result;
}

void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
{
    if (conf_int_alpha<0 || conf_int_alpha>=1)
    {
        SG_ERROR("%f is an illegal alpha-value for confidence interval of "
                "cross-validation\n", conf_int_alpha);
    }

    if (m_num_runs==1)
    {
        SG_WARNING("Confidence interval for cross-validation only possible"
                " when number of runs is >1, ignoring.\n");
    }
    else
        m_conf_int_alpha=conf_int_alpha;
}

void CCrossValidation::set_num_runs(int32_t num_runs)
{
    if (num_runs<1)
        SG_ERROR("%d is an illegal number of repetitions\n", num_runs);

    m_num_runs=num_runs;
}

float64_t CCrossValidation::evaluate_one_run()
{
    SG_DEBUG("entering %s::evaluate_one_run()\n", get_name());
    index_t num_subsets=m_splitting_strategy->get_num_subsets();

    SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets);

    /* build index sets */
    m_splitting_strategy->build_subsets();

    /* results array */
    SGVector<float64_t> results(num_subsets);

    /* different behavior whether data is locked or not */
    if (m_machine->is_data_locked())
    {
        SG_DEBUG("starting locked evaluation\n", get_name());
        /* do actual cross-validation */
        for (index_t i=0; i<num_subsets; ++i)
        {
            /* possibly update cross-validation output classes */
            CCrossValidationOutput* current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_fold_index(i);
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            /* index subset for training, will be freed below */
            SGVector<index_t> inverse_subset_indices=
                    m_splitting_strategy->generate_subset_inverse(i);

            /* train machine on training features */
            m_machine->train_locked(inverse_subset_indices);

            /* feature subset for testing */
            SGVector<index_t> subset_indices=
                    m_splitting_strategy->generate_subset_indices(i);

            /* possibly update cross-validation output classes */
            current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_train_indices(inverse_subset_indices, "\t");
                current->update_trained_machine(m_machine, "\t");
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            /* produce output for desired indices */
            CLabels* result_labels=m_machine->apply_locked(subset_indices);
            SG_REF(result_labels);

            /* set subset for testing labels */
            m_labels->add_subset(subset_indices);

            /* evaluate against own labels */
            m_evaluation_criterion->set_indices(subset_indices);
            results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);

            /* possibly update cross-validation output classes */
            current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_test_indices(subset_indices, "\t");
                current->update_test_result(result_labels, "\t");
                current->update_test_true_result(m_labels, "\t");
                current->post_update_results();
                current->update_evaluation_result(results[i], "\t");
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            /* remove subset to prevent side effects */
            m_labels->remove_subset();

            /* clean up */
            SG_UNREF(result_labels);

            SG_DEBUG("done locked evaluation\n", get_name());
        }
    }
    else
    {
        SG_DEBUG("starting unlocked evaluation\n", get_name());
        /* tell machine to store model internally
         * (otherwise changing the subset of features would break the
         * trained classifier) */
        m_machine->set_store_model_features(true);

        /* do actual cross-validation */
        for (index_t i=0; i<num_subsets; ++i)
        {
            /* possibly update cross-validation output classes */
            CCrossValidationOutput* current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_fold_index(i);
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            /* set feature subset for training */
            SGVector<index_t> inverse_subset_indices=
                    m_splitting_strategy->generate_subset_inverse(i);
            m_features->add_subset(inverse_subset_indices);
            for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
            {
                CPreprocessor* preprocessor=m_features->get_preprocessor(p);
                preprocessor->init(m_features);
                SG_UNREF(preprocessor);
            }

            /* set label subset for training */
            m_labels->add_subset(inverse_subset_indices);

            SG_DEBUG("training set %d:\n", i);
            if (io->get_loglevel()==MSG_DEBUG)
            {
                SGVector<index_t>::display_vector(inverse_subset_indices.vector,
                        inverse_subset_indices.vlen, "training indices");
            }

            /* train machine on training features and remove subset */
            SG_DEBUG("starting training\n");
            m_machine->train(m_features);
            SG_DEBUG("finished training\n");

            /* possibly update cross-validation output classes */
            current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_train_indices(inverse_subset_indices, "\t");
                current->update_trained_machine(m_machine, "\t");
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            m_features->remove_subset();
            m_labels->remove_subset();

            /* set feature subset for testing (subset method that stores pointer) */
            SGVector<index_t> subset_indices=
                    m_splitting_strategy->generate_subset_indices(i);
            m_features->add_subset(subset_indices);

            /* set label subset for testing */
            m_labels->add_subset(subset_indices);

            SG_DEBUG("test set %d:\n", i);
            if (io->get_loglevel()==MSG_DEBUG)
            {
                SGVector<index_t>::display_vector(subset_indices.vector,
                        subset_indices.vlen, "test indices");
            }

            /* apply machine to test features and remove subset */
            SG_DEBUG("starting evaluation\n");
            SG_DEBUG("%p\n", m_features);
            CLabels* result_labels=m_machine->apply(m_features);
            SG_DEBUG("finished evaluation\n");
            m_features->remove_subset();
            SG_REF(result_labels);

            /* evaluate */
            results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
            SG_DEBUG("result on fold %d is %f\n", i, results[i]);

            /* possibly update cross-validation output classes */
            current=(CCrossValidationOutput*)
                    m_xval_outputs->get_first_element();
            while (current)
            {
                current->update_test_indices(subset_indices, "\t");
                current->update_test_result(result_labels, "\t");
                current->update_test_true_result(m_labels, "\t");
                current->post_update_results();
                current->update_evaluation_result(results[i], "\t");
                SG_UNREF(current);
                current=(CCrossValidationOutput*)
                        m_xval_outputs->get_next_element();
            }

            /* clean up, remove subsets */
            SG_UNREF(result_labels);
            m_labels->remove_subset();
        }

        SG_DEBUG("done unlocked evaluation\n", get_name());
    }

    /* build arithmetic mean of results */
    float64_t mean=CStatistics::mean(results);

    SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name());
    return mean;
}

void CCrossValidation::add_cross_validation_output(
        CCrossValidationOutput* cross_validation_output)
{
    m_xval_outputs->append_element(cross_validation_output);
}
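
Typical usage (not part of this file): a minimal sketch of driving CCrossValidation from user code. The svm, features and labels objects are assumed to exist already, and CStratifiedCrossValidationSplitting and CAccuracyMeasure are classes from other parts of the toolbox; only the CCrossValidation calls themselves come from the file above.

    /* usage sketch: 5-fold stratified cross-validation, 10 repetitions */
    CStratifiedCrossValidationSplitting* splitting=
            new CStratifiedCrossValidationSplitting(labels, 5);
    CAccuracyMeasure* eval_crit=new CAccuracyMeasure();
    CCrossValidation* cross=new CCrossValidation(svm, features, labels,
            splitting, eval_crit);

    /* repetitions must be >1 before a confidence interval can be requested */
    cross->set_num_runs(10);
    cross->set_conf_int_alpha(0.05);

    /* evaluate() returns a ref-counted result; mean (and, if requested,
     * the confidence interval bounds) are stored in its public members */
    CCrossValidationResult* result=
            (CCrossValidationResult*)cross->evaluate();
    SG_SPRINT("cross-validation mean: %f\n", result->mean);

    SG_UNREF(result);
    SG_UNREF(cross);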
