SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
18 #include <shogun/lib/List.h>
19 
20 using namespace shogun;
21 
23 {
24  init();
25 }
26 
28  CLabels* labels, CSplittingStrategy* splitting_strategy,
29  CEvaluation* evaluation_criterion, bool autolock) :
30  CMachineEvaluation(machine, features, labels, splitting_strategy,
31  evaluation_criterion, autolock)
32 {
33  init();
34 }
35 
37  CSplittingStrategy* splitting_strategy,
38  CEvaluation* evaluation_criterion, bool autolock) :
39  CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
40  autolock)
41 {
42  init();
43 }
44 
46 {
48 }
49 
50 void CCrossValidation::init()
51 {
52  m_num_runs=1;
54 
55  /* do reference counting for output objects */
56  m_xval_outputs=new CList(true);
57 
58  SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
60  SG_ADD(&m_conf_int_alpha, "conf_int_alpha", "alpha-value "
61  "of confidence interval", MS_NOT_AVAILABLE);
62  SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
63  "classes for intermediade cross-validation results",
65 }
66 
68 {
69  SG_DEBUG("entering %s::evaluate()\n", get_name())
70 
71  REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
72  "attached\n", get_name());
73 
74  REQUIRE(m_features, "%s::evaluate() is only possible if features are "
75  "attached\n", get_name());
76 
77  REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
78  "attached\n", get_name());
79 
80  /* if for some reason the do_unlock_frag is set, unlock */
81  if (m_do_unlock)
82  {
84  m_do_unlock=false;
85  }
86 
87  /* set labels in any case (no locking needs this) */
89 
90  if (m_autolock)
91  {
92  /* if machine supports locking try to do so */
94  {
95  /* only lock if machine is not yet locked */
96  if (!m_machine->is_data_locked())
97  {
99  m_do_unlock=true;
100  }
101  }
102  else
103  {
104  SG_WARNING("%s does not support locking. Autolocking is skipped. "
105  "Set autolock flag to false to get rid of warning.\n",
106  m_machine->get_name());
107  }
108  }
109 
111 
112  /* evtl. update xvalidation output class */
115  while (current)
116  {
117  current->init_num_runs(m_num_runs);
119  current->init_expose_labels(m_labels);
120  current->post_init();
121  SG_UNREF(current);
122  current=(CCrossValidationOutput*)
124  }
125 
126  /* perform all the x-val runs */
127  SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs)
128  for (index_t i=0; i <m_num_runs; ++i)
129  {
130 
131  /* evtl. update xvalidation output class */
133  while (current)
134  {
135  current->update_run_index(i);
136  SG_UNREF(current);
137  current=(CCrossValidationOutput*)
139  }
140 
141  SG_DEBUG("entering cross-validation run %d \n", i)
142  results[i]=evaluate_one_run();
143  SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])
144  }
145 
146  /* construct evaluation result */
148  result->has_conf_int=m_conf_int_alpha != 0;
150 
151  if (result->has_conf_int)
152  {
155  result->conf_int_alpha, result->conf_int_low, result->conf_int_up);
156  }
157  else
158  {
159  result->mean=CStatistics::mean(results);
160  result->conf_int_low=0;
161  result->conf_int_up=0;
162  }
163 
164  /* unlock machine if it was locked in this method */
166  {
168  m_do_unlock=false;
169  }
170 
171  SG_DEBUG("leaving %s::evaluate()\n", get_name())
172 
173  SG_REF(result);
174  return result;
175 }
176 
178 {
179  if (conf_int_alpha <0 || conf_int_alpha>= 1) {
180  SG_ERROR("%f is an illegal alpha-value for confidence interval of "
181  "cross-validation\n", conf_int_alpha);
182  }
183 
184  if (m_num_runs==1)
185  {
186  SG_WARNING("Confidence interval for Cross-Validation only possible"
187  " when number of runs is >1, ignoring.\n");
188  }
189  else
190  m_conf_int_alpha=conf_int_alpha;
191 }
192 
193 void CCrossValidation::set_num_runs(int32_t num_runs)
194 {
195  if (num_runs <1)
196  SG_ERROR("%d is an illegal number of repetitions\n", num_runs)
197 
198  m_num_runs=num_runs;
199 }
200 
202 {
203  SG_DEBUG("entering %s::evaluate_one_run()\n", get_name())
205 
206  SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets)
207 
208  /* build index sets */
210 
211  /* results array */
212  SGVector<float64_t> results(num_subsets);
213 
214  /* different behavior whether data is locked or not */
215  if (m_machine->is_data_locked())
216  {
217  SG_DEBUG("starting locked evaluation\n", get_name())
218  /* do actual cross-validation */
219  for (index_t i=0; i <num_subsets; ++i)
220  {
221  /* evtl. update xvalidation output class */
224  while (current)
225  {
226  current->update_fold_index(i);
227  SG_UNREF(current);
228  current=(CCrossValidationOutput*)
230  }
231 
232  /* index subset for training, will be freed below */
233  SGVector<index_t> inverse_subset_indices =
235 
236  /* train machine on training features */
237  m_machine->train_locked(inverse_subset_indices);
238 
239  /* feature subset for testing */
240  SGVector<index_t> subset_indices =
242 
243  /* evtl. update xvalidation output class */
245  while (current)
246  {
247  current->update_train_indices(inverse_subset_indices, "\t");
248  current->update_trained_machine(m_machine, "\t");
249  SG_UNREF(current);
250  current=(CCrossValidationOutput*)
252  }
253 
254  /* produce output for desired indices */
255  CLabels* result_labels=m_machine->apply_locked(subset_indices);
256  SG_REF(result_labels);
257 
258  /* set subset for testing labels */
259  m_labels->add_subset(subset_indices);
260 
261  /* evaluate against own labels */
262  m_evaluation_criterion->set_indices(subset_indices);
263  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
264 
265  /* evtl. update xvalidation output class */
267  while (current)
268  {
269  current->update_test_indices(subset_indices, "\t");
270  current->update_test_result(result_labels, "\t");
271  current->update_test_true_result(m_labels, "\t");
272  current->post_update_results();
273  current->update_evaluation_result(results[i], "\t");
274  SG_UNREF(current);
275  current=(CCrossValidationOutput*)
277  }
278 
279  /* remove subset to prevent side effects */
281 
282  /* clean up */
283  SG_UNREF(result_labels);
284 
285  SG_DEBUG("done locked evaluation\n", get_name())
286  }
287  }
288  else
289  {
290  SG_DEBUG("starting unlocked evaluation\n", get_name())
291  /* tell machine to store model internally
292  * (otherwise changing subset of features will kaboom the classifier) */
294 
295  /* do actual cross-validation */
296  for (index_t i=0; i <num_subsets; ++i)
297  {
298  /* evtl. update xvalidation output class */
301  while (current)
302  {
303  current->update_fold_index(i);
304  SG_UNREF(current);
305  current=(CCrossValidationOutput*)
307  }
308 
309  /* set feature subset for training */
310  SGVector<index_t> inverse_subset_indices=
312  m_features->add_subset(inverse_subset_indices);
313  for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
314  {
315  CPreprocessor* preprocessor = m_features->get_preprocessor(p);
316  preprocessor->init(m_features);
317  SG_UNREF(preprocessor);
318  }
319 
320  /* set label subset for training */
321  m_labels->add_subset(inverse_subset_indices);
322 
323  SG_DEBUG("training set %d:\n", i)
324  if (io->get_loglevel()==MSG_DEBUG)
325  {
326  SGVector<index_t>::display_vector(inverse_subset_indices.vector,
327  inverse_subset_indices.vlen, "training indices");
328  }
329 
330  /* train machine on training features and remove subset */
331  SG_DEBUG("starting training\n")
333  SG_DEBUG("finished training\n")
334 
335  /* evtl. update xvalidation output class */
337  while (current)
338  {
339  current->update_train_indices(inverse_subset_indices, "\t");
340  current->update_trained_machine(m_machine, "\t");
341  SG_UNREF(current);
342  current=(CCrossValidationOutput*)
344  }
345 
348 
349  /* set feature subset for testing (subset method that stores pointer) */
350  SGVector<index_t> subset_indices =
352  m_features->add_subset(subset_indices);
353 
354  /* set label subset for testing */
355  m_labels->add_subset(subset_indices);
356 
357  SG_DEBUG("test set %d:\n", i)
358  if (io->get_loglevel()==MSG_DEBUG)
359  {
361  subset_indices.vlen, "test indices");
362  }
363 
364  /* apply machine to test features and remove subset */
365  SG_DEBUG("starting evaluation\n")
366  SG_DEBUG("%p\n", m_features)
367  CLabels* result_labels=m_machine->apply(m_features);
368  SG_DEBUG("finished evaluation\n")
370  SG_REF(result_labels);
371 
372  /* evaluate */
373  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
374  SG_DEBUG("result on fold %d is %f\n", i, results[i])
375 
376  /* evtl. update xvalidation output class */
378  while (current)
379  {
380  current->update_test_indices(subset_indices, "\t");
381  current->update_test_result(result_labels, "\t");
382  current->update_test_true_result(m_labels, "\t");
383  current->post_update_results();
384  current->update_evaluation_result(results[i], "\t");
385  SG_UNREF(current);
386  current=(CCrossValidationOutput*)
388  }
389 
390  /* clean up, remove subsets */
391  SG_UNREF(result_labels);
393  }
394 
395  SG_DEBUG("done unlocked evaluation\n", get_name())
396  }
397 
398  /* build arithmetic mean of results */
399  float64_t mean=CStatistics::mean(results);
400 
401  SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name())
402  return mean;
403 }
404 
406  CCrossValidationOutput* cross_validation_output)
407 {
408  m_xval_outputs->append_element(cross_validation_output);
409 }
virtual void update_fold_index(index_t fold_index, const char *prefix="")
virtual void build_subsets()=0
virtual void update_train_indices(SGVector< index_t > indices, const char *prefix="")
virtual bool init(CFeatures *features)=0
CSGObject * get_next_element()
Definition: List.h:185
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:187
int32_t index_t
Definition: common.h:62
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
void set_conf_int_alpha(float64_t m_conf_int_alpha)
static float64_t confidence_intervals_mean(SGVector< float64_t > values, float64_t alpha, float64_t &conf_int_low, float64_t &conf_int_up)
Definition: Statistics.cpp:335
virtual CEvaluationResult * evaluate()
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
virtual void update_test_true_result(CLabels *results, const char *prefix="")
Abstract base class for all splitting types. Takes a CLabels instance and generates a desired number of subsets, which are accessed by index via generate_subset_indices() and generate_subset_inverse().
virtual void init_num_runs(index_t num_runs, const char *prefix="")
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
virtual void update_test_indices(SGVector< index_t > indices, const char *prefix="")
CPreprocessor * get_preprocessor(int32_t num) const
Definition: Features.cpp:93
Type to encapsulate the results of an evaluation run. May contain a confidence interval (if conf_int_alpha != 0).
virtual const char * get_name() const
Definition: Machine.h:305
virtual bool train_locked(SGVector< index_t > indices)
Definition: Machine.h:239
#define SG_REF(x)
Definition: SGObject.h:51
void set_num_runs(int32_t num_runs)
A generic learning machine interface.
Definition: Machine.h:143
virtual void set_indices(SGVector< index_t > indices)
Definition: Evaluation.h:63
void display_vector(const char *name="vector", const char *prefix="") const
Definition: SGVector.cpp:356
int32_t get_num_preprocessors() const
Definition: Features.cpp:155
virtual void update_trained_machine(CMachine *machine, const char *prefix="")
index_t vlen
Definition: SGVector.h:494
CSGObject * get_first_element()
Definition: List.h:151
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:107
Class for managing individual folds in cross-validation.
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
double float64_t
Definition: common.h:50
virtual void data_unlock()
Definition: Machine.cpp:143
virtual const char * get_name() const
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:112
virtual void remove_subset()
Definition: Labels.cpp:49
Abstract class that contains the result generated by the MachineEvaluation class. ...
Machine Evaluation is an abstract class that evaluates a machine according to some criterion...
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
SGVector< index_t > generate_subset_inverse(index_t subset_idx)
static floatmax_t mean(SGVector< T > vec)
Definition: Statistics.h:44
EMessageType get_loglevel() const
Definition: SGIO.cpp:285
virtual void update_test_result(CLabels *results, const char *prefix="")
virtual bool supports_locking() const
Definition: Machine.h:293
virtual float64_t evaluate_one_run()
#define SG_UNREF(x)
Definition: SGObject.h:52
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
SGVector< index_t > generate_subset_indices(index_t subset_idx)
virtual void remove_subset()
Definition: Features.cpp:322
virtual void update_evaluation_result(float64_t result, const char *prefix="")
The class Features is the base class of all feature objects.
Definition: Features.h:68
bool append_element(CSGObject *data)
Definition: List.h:331
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:39
Class Preprocessor defines a preprocessor interface.
Definition: Preprocessor.h:75
void add_cross_validation_output(CCrossValidationOutput *cross_validation_output)
#define SG_WARNING(...)
Definition: SGIO.h:128
#define SG_ADD(...)
Definition: SGObject.h:81
virtual void init_expose_labels(CLabels *labels)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
bool is_data_locked() const
Definition: Machine.h:296
virtual void init_num_folds(index_t num_folds, const char *prefix="")
virtual void update_run_index(index_t run_index, const char *prefix="")
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classification.
Definition: Evaluation.h:40
CSplittingStrategy * m_splitting_strategy
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
Class List implements a doubly linked list for low-level objects.
Definition: List.h:84
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:152

SHOGUN Machine Learning Toolbox - Documentation