SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CrossValidation.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/machine/Machine.h>
15 #include <shogun/base/Parameter.h>
18 #include <shogun/lib/List.h>
19 
20 using namespace shogun;
21 
/* NOTE(review): upstream line 22 (the constructor signature) is missing from
 * this listing; presumably the default constructor -- confirm against
 * CrossValidation.cpp. The visible body only calls init(). */
23 {
24  init();
25 }
26 
/* Constructor taking a machine, features, labels, a splitting strategy, an
 * evaluation criterion and the autolock flag. All of them are forwarded
 * unchanged to CMachineEvaluation; init() then sets the cross-validation
 * specific defaults.
 * NOTE(review): the first signature line (upstream line 27) is missing from
 * this listing. */
28  CLabels* labels, CSplittingStrategy* splitting_strategy,
29  CEvaluation* evaluation_criterion, bool autolock) :
30  CMachineEvaluation(machine, features, labels, splitting_strategy,
31  evaluation_criterion, autolock)
32 {
33  init();
34 }
35 
/* Constructor variant without a features argument. The machine, labels,
 * splitting strategy, evaluation criterion and autolock flag are forwarded
 * to CMachineEvaluation; init() then sets the cross-validation defaults.
 * NOTE(review): the first signature line (upstream line 36) is missing from
 * this listing. */
37  CSplittingStrategy* splitting_strategy,
38  CEvaluation* evaluation_criterion, bool autolock) :
39  CMachineEvaluation(machine, labels, splitting_strategy, evaluation_criterion,
40  autolock)
41 {
42  init();
43 }
44 
/* NOTE(review): upstream lines 45 and 47 are missing from this listing, so
 * only the braces survive. Presumably this is the destructor, and the missing
 * body line releases m_xval_outputs -- confirm against CrossValidation.cpp. */
46 {
48 }
49 
/* Shared initialisation called by every constructor:
 *  - m_num_runs defaults to a single repetition;
 *  - m_xval_outputs is a CList created with reference counting enabled
 *    (CList(true)), so the list holds references to the output objects;
 *  - both members are registered as parameters via SG_ADD.
 * NOTE(review): upstream lines 58 and 61 (the trailing arguments of the two
 * SG_ADD calls) are missing from this listing.
 * NOTE(review): the registered description string misspells "intermediate"
 * as "intermediade"; fixing it only changes the parameter description text. */
50 void CCrossValidation::init()
51 {
52  m_num_runs=1;
53
54  /* do reference counting for output objects */
55  m_xval_outputs=new CList(true);
56
57  SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
59  SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
60  "classes for intermediade cross-validation results",
62 }
63 
/* CCrossValidation::evaluate() -- repeats cross-validation m_num_runs times.
 *
 * Preconditions: a machine, features and labels must be attached; this is
 * checked with REQUIRE.
 *
 * Locking: if m_autolock is set and the machine supports locking, the
 * in-code comments say the machine's data is locked here (the lock call is
 * on an omitted line), and m_do_unlock records that this method must unlock
 * it again before returning. If the machine does not support locking, a
 * warning is printed instead.
 *
 * Result: each run's score comes from evaluate_one_run(). The returned result
 * holds the mean of those scores, and their standard deviation when
 * m_num_runs > 1 (0 for a single run). It is SG_REF'd before being returned,
 * so the caller receives a reference it is presumably expected to SG_UNREF.
 *
 * NOTE(review): this listing omits the signature (upstream line 64) and
 * statement lines 80, 85, 90, 95, 107, 110, 111, 115, 120, 129, 135, 144, 152
 * and 154. Statements on those lines (e.g. the declarations of `current`,
 * `results` and `result`) are not described here. */
65 {
66  SG_DEBUG("entering %s::evaluate()\n", get_name())
67
68  REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
69  "attached\n", get_name());
70
71  REQUIRE(m_features, "%s::evaluate() is only possible if features are "
72  "attached\n", get_name());
73
74  REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
75  "attached\n", get_name());
76
77  /* if for some reason the do_unlock_frag is set, unlock */
78  if (m_do_unlock)
79  {
81  m_do_unlock=false;
82  }
83
84  /* set labels in any case (no locking needs this) */
86
87  if (m_autolock)
88  {
89  /* if machine supports locking try to do so */
91  {
92  /* only lock if machine is not yet locked */
93  if (!m_machine->is_data_locked())
94  {
96  m_do_unlock=true;
97  }
98  }
99  else
100  {
101  SG_WARNING("%s does not support locking. Autolocking is skipped. "
102  "Set autolock flag to false to get rid of warning.\n",
103  m_machine->get_name());
104  }
105  }
106
108
109  /* evtl. update xvalidation output class */
/* Each output object is told the number of runs and the labels, then
 * post_init() is called on it. */
112  while (current)
113  {
114  current->init_num_runs(m_num_runs);
116  current->init_expose_labels(m_labels);
117  current->post_init();
118  SG_UNREF(current);
119  current=(CCrossValidationOutput*)
121  }
122
123  /* perform all the x-val runs */
124  SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs)
125  for (index_t i=0; i <m_num_runs; ++i)
126  {
127
128  /* evtl. update xvalidation output class */
130  while (current)
131  {
132  current->update_run_index(i);
133  SG_UNREF(current);
134  current=(CCrossValidationOutput*)
136  }
137
138  SG_DEBUG("entering cross-validation run %d \n", i)
139  results[i]=evaluate_one_run();
140  SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])
141  }
142
143  /* construct evaluation result */
145  result->mean=CStatistics::mean(results);
/* Standard deviation is only computed when there is more than one run. */
146  if (m_num_runs>1)
147  result->std_dev=CStatistics::std_deviation(results);
148  else
149  result->std_dev=0;
150
151  /* unlock machine if it was locked in this method */
153  {
155  m_do_unlock=false;
156  }
157
158  SG_DEBUG("leaving %s::evaluate()\n", get_name())
159
/* Hand an owned reference to the caller. */
160  SG_REF(result);
161  return result;
162 }
163 
/* Sets how many times the complete cross-validation is repeated by evaluate().
 * @param num_runs number of repetitions; values below 1 are reported through
 *        SG_ERROR as an illegal number of repetitions. */
164 void CCrossValidation::set_num_runs(int32_t num_runs)
165 {
166  if (num_runs <1)
167  SG_ERROR("%d is an illegal number of repetitions\n", num_runs)
168
169  m_num_runs=num_runs;
170 }
171 
/* CCrossValidation::evaluate_one_run() -- one full pass over all folds.
 * Returns the arithmetic mean (CStatistics::mean) of the per-fold scores
 * stored in `results`.
 *
 * Locked branch (m_machine->is_data_locked()):
 *  - the machine is trained with train_locked() and predicts with
 *    apply_locked() on index subsets;
 *  - m_labels temporarily receives the test subset for scoring.
 *
 * Unlocked branch: folds run under `#pragma omp parallel for`.
 *  - With one thread (get_global_parallel()->get_num_threads()==1) the shared
 *    m_machine, m_features and m_labels are used directly.
 *  - Otherwise each fold uses a clone of the machine, a clone of the features
 *    and the clone's get_labels(). All three are SG_UNREF'd at the end of the
 *    fold.
 *  - Calls into the output objects and into m_evaluation_criterion are
 *    wrapped in `#pragma omp critical`.
 *
 * NOTE(review): this listing omits the signature (upstream line 172) and
 * lines 175, 180, 193, 194, 200, 205, 212, 215, 222, 237, 247, 251, 264, 280,
 * 281, 289, 295, 326, 333, 342, 351, 373 and 383. These presumably include the
 * definition of num_subsets, the right-hand sides of the index-vector
 * initialisations and the output-list iteration. */
173 {
174  SG_DEBUG("entering %s::evaluate_one_run()\n", get_name())
176
177  SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets)
178
179  /* build index sets */
181
182  /* results array */
183  SGVector<float64_t> results(num_subsets);
184
185  /* different behavior whether data is locked or not */
186  if (m_machine->is_data_locked())
187  {
/* NOTE(review): get_name() is passed but this format string has no %s (same
 * at upstream lines 256, 261 and 398); the extra argument is ignored. */
188  SG_DEBUG("starting locked evaluation\n", get_name())
189  /* do actual cross-validation */
190  for (index_t i=0; i <num_subsets; ++i)
191  {
192  /* evtl. update xvalidation output class */
195  while (current)
196  {
197  current->update_fold_index(i);
198  SG_UNREF(current);
199  current=(CCrossValidationOutput*)
201  }
202
203  /* index subset for training, will be freed below */
204  SGVector<index_t> inverse_subset_indices =
206
207  /* train machine on training features */
208  m_machine->train_locked(inverse_subset_indices);
209
210  /* feature subset for testing */
211  SGVector<index_t> subset_indices =
213
214  /* evtl. update xvalidation output class */
216  while (current)
217  {
218  current->update_train_indices(inverse_subset_indices, "\t");
219  current->update_trained_machine(m_machine, "\t");
220  SG_UNREF(current);
221  current=(CCrossValidationOutput*)
223  }
224
225  /* produce output for desired indices */
226  CLabels* result_labels=m_machine->apply_locked(subset_indices);
227  SG_REF(result_labels);
228
229  /* set subset for testing labels */
230  m_labels->add_subset(subset_indices);
231
232  /* evaluate against own labels */
/* NOTE(review): only this locked branch calls set_indices() on the
 * criterion; the unlocked branch below does not. */
233  m_evaluation_criterion->set_indices(subset_indices);
234  results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
235
236  /* evtl. update xvalidation output class */
238  while (current)
239  {
240  current->update_test_indices(subset_indices, "\t");
241  current->update_test_result(result_labels, "\t");
242  current->update_test_true_result(m_labels, "\t");
243  current->post_update_results();
244  current->update_evaluation_result(results[i], "\t");
245  SG_UNREF(current);
246  current=(CCrossValidationOutput*)
248  }
249
250  /* remove subset to prevent side effects */
252
253  /* clean up */
254  SG_UNREF(result_labels);
255
/* NOTE(review): this message is inside the fold loop, so it is printed once
 * per fold rather than once per locked evaluation. */
256  SG_DEBUG("done locked evaluation\n", get_name())
257  }
258  }
259  else
260  {
261  SG_DEBUG("starting unlocked evaluation\n", get_name())
262  /* tell machine to store model internally
263  * (otherwise changing subset of features will kaboom the classifier) */
265
266  /* do actual cross-validation */
/* NOTE(review): verify that the OpenMP team size matches
 * get_global_parallel()->get_num_threads(). If OpenMP can run more than one
 * thread while that call returns 1, concurrent folds would share and mutate
 * (add_subset/remove_subset/train) m_machine, m_features and m_labels. */
267  #pragma omp parallel for
268  for (index_t i=0; i <num_subsets; ++i)
269  {
270  CMachine* machine;
271  CFeatures* features;
272  CLabels* labels;
273
/* One thread: work on the shared machine; otherwise on a private clone. */
274  if (get_global_parallel()->get_num_threads()==1)
275  machine=m_machine;
276  else
277  machine=(CMachine*)m_machine->clone();
278
279  /* evtl. update xvalidation output class */
282  #pragma omp critical
283  {
284  while (current)
285  {
286  current->update_fold_index(i);
287  SG_UNREF(current);
288  current=(CCrossValidationOutput*)
290  }
291  }
292
293  /* set feature subset for training */
294  SGVector<index_t> inverse_subset_indices=
296
297  if (get_global_parallel()->get_num_threads()==1)
298  features=m_features;
299  else
300  features=(CFeatures*)m_features->clone();
301
302  features->add_subset(inverse_subset_indices);
303
304  /* set label subset for training */
/* With several threads the labels come from the cloned machine; they are
 * SG_UNREF'd in the clean-up below, consistent with get_labels() returning
 * a new reference (presumably -- confirm in Machine.cpp). */
305  if (get_global_parallel()->get_num_threads()==1)
306  labels=m_labels;
307  else
308  labels=machine->get_labels();
309  labels->add_subset(inverse_subset_indices);
310
311  SG_DEBUG("training set %d:\n", i)
312  if (io->get_loglevel()==MSG_DEBUG)
313  {
314  SGVector<index_t>::display_vector(inverse_subset_indices.vector,
315  inverse_subset_indices.vlen, "training indices");
316  }
317
318  /* train machine on training features and remove subset */
319  SG_DEBUG("starting training\n")
320  machine->train(features);
321  SG_DEBUG("finished training\n")
322
323  /* evtl. update xvalidation output class */
324  #pragma omp critical
325  {
327  while (current)
328  {
329  current->update_train_indices(inverse_subset_indices, "\t");
330  current->update_trained_machine(machine, "\t");
331  SG_UNREF(current);
332  current=(CCrossValidationOutput*)
334  }
335  }
336
337  features->remove_subset();
338  labels->remove_subset();
339
340  /* set feature subset for testing (subset method that stores pointer) */
341  SGVector<index_t> subset_indices =
343  features->add_subset(subset_indices);
344
345  /* set label subset for testing */
346  labels->add_subset(subset_indices);
347
348  SG_DEBUG("test set %d:\n", i)
349  if (io->get_loglevel()==MSG_DEBUG)
350  {
352  subset_indices.vlen, "test indices");
353  }
354
355  /* apply machine to test features and remove subset */
356  SG_DEBUG("starting evaluation\n")
357  SG_DEBUG("%p\n", features)
358  CLabels* result_labels=machine->apply(features);
359  SG_DEBUG("finished evaluation\n")
360  features->remove_subset();
361  SG_REF(result_labels);
362
363  /* evaluate */
/* The criterion object is shared by all threads, hence the critical section;
 * each fold writes only its own slot results[i]. */
364  #pragma omp critical
365  {
366  results[i]=m_evaluation_criterion->evaluate(result_labels, labels);
367  SG_DEBUG("result on fold %d is %f\n", i, results[i])
368  }
369
370  /* evtl. update xvalidation output class */
371  #pragma omp critical
372  {
374  while (current)
375  {
376  current->update_test_indices(subset_indices, "\t");
377  current->update_test_result(result_labels, "\t");
378  current->update_test_true_result(labels, "\t");
379  current->post_update_results();
380  current->update_evaluation_result(results[i], "\t");
381  SG_UNREF(current);
382  current=(CCrossValidationOutput*)
384  }
385  }
386
387  /* clean up, remove subsets */
388  labels->remove_subset();
/* Release the per-fold clones (only created when more than one thread). */
389  if (get_global_parallel()->get_num_threads()!=1)
390  {
391  SG_UNREF(machine);
392  SG_UNREF(features);
393  SG_UNREF(labels);
394  }
395  SG_UNREF(result_labels);
396  }
397
398  SG_DEBUG("done unlocked evaluation\n", get_name())
399  }
400
401  /* build arithmetic mean of results */
402  float64_t mean=CStatistics::mean(results);
403
404  SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name())
405  return mean;
406 }
407 
/* Appends cross_validation_output to m_xval_outputs. That list is presumably
 * the one iterated as `current` in evaluate()/evaluate_one_run(); the
 * iteration lines are not visible in this listing.
 * NOTE(review): upstream line 408 (the start of the signature) is missing from
 * this listing. */
409  CCrossValidationOutput* cross_validation_output)
410 {
411  m_xval_outputs->append_element(cross_validation_output);
412 }
virtual void update_fold_index(index_t fold_index, const char *prefix="")
virtual void build_subsets()=0
virtual void update_train_indices(SGVector< index_t > indices, const char *prefix="")
Parallel * get_global_parallel()
Definition: SGObject.cpp:310
CSGObject * get_next_element()
Definition: List.h:185
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:187
int32_t index_t
Definition: common.h:62
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual CSGObject * clone()
Definition: SGObject.cpp:747
static float64_t std_deviation(SGVector< float64_t > values)
Definition: Statistics.cpp:120
virtual CEvaluationResult * evaluate()
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)=0
virtual void update_test_true_result(CLabels *results, const char *prefix="")
Abstract base class for all splitting types. Takes a CLabels instance and generates a desired number ...
virtual void init_num_runs(index_t num_runs, const char *prefix="")
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
virtual void update_test_indices(SGVector< index_t > indices, const char *prefix="")
Type to encapsulate the results of an evaluation run.
virtual const char * get_name() const
Definition: Machine.h:305
virtual bool train_locked(SGVector< index_t > indices)
Definition: Machine.h:239
#define SG_REF(x)
Definition: SGObject.h:54
void set_num_runs(int32_t num_runs)
A generic learning machine interface.
Definition: Machine.h:143
virtual void set_indices(SGVector< index_t > indices)
Definition: Evaluation.h:63
void display_vector(const char *name="vector", const char *prefix="") const
Definition: SGVector.cpp:354
virtual void update_trained_machine(CMachine *machine, const char *prefix="")
index_t vlen
Definition: SGVector.h:494
CSGObject * get_first_element()
Definition: List.h:151
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:107
Class for managing individual folds in cross-validation.
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
double float64_t
Definition: common.h:50
virtual void data_unlock()
Definition: Machine.cpp:143
virtual const char * get_name() const
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:112
virtual void remove_subset()
Definition: Labels.cpp:49
Abstract class that contains the result generated by the MachineEvaluation class. ...
Machine Evaluation is an abstract class that evaluates a machine according to some criterion...
virtual CLabels * get_labels()
Definition: Machine.cpp:76
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
SGVector< index_t > generate_subset_inverse(index_t subset_idx)
static floatmax_t mean(SGVector< T > vec)
Definition: Statistics.h:42
EMessageType get_loglevel() const
Definition: SGIO.cpp:285
virtual void update_test_result(CLabels *results, const char *prefix="")
virtual bool supports_locking() const
Definition: Machine.h:293
virtual float64_t evaluate_one_run()
#define SG_UNREF(x)
Definition: SGObject.h:55
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
SGVector< index_t > generate_subset_indices(index_t subset_idx)
virtual void remove_subset()
Definition: Features.cpp:322
virtual void update_evaluation_result(float64_t result, const char *prefix="")
The class Features is the base class of all feature objects.
Definition: Features.h:68
bool append_element(CSGObject *data)
Definition: List.h:331
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:39
void add_cross_validation_output(CCrossValidationOutput *cross_validation_output)
#define SG_WARNING(...)
Definition: SGIO.h:128
#define SG_ADD(...)
Definition: SGObject.h:84
virtual void init_expose_labels(CLabels *labels)
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
bool is_data_locked() const
Definition: Machine.h:296
virtual void init_num_folds(index_t num_folds, const char *prefix="")
virtual void update_run_index(index_t run_index, const char *prefix="")
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classifica...
Definition: Evaluation.h:40
CSplittingStrategy * m_splitting_strategy
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
Class List implements a doubly linked list for low-level objects.
Definition: List.h:84
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:152

SHOGUN Machine Learning Toolbox - Documentation