SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CrossValidationMulticlassStorage.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn
8  *
9  */
10 
11 #ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_
12 #define CROSSVALIDATIONMULTICLASSSTORAGE_H_
13 
17 #include <shogun/lib/SGMatrix.h>
19 
20 namespace shogun
21 {
22 
23 class CMachine;
24 class CLabels;
25 class CEvaluation;
26 
32 {
33 public:
34 
// Constructor. Defaults: compute_ROC=true, compute_PRC=false,
// compute_conf_matrices=false.
// NOTE(review): the definition is not part of this listing. Presumably these
// flags initialise m_compute_ROC, m_compute_PRC and m_compute_conf_matrices,
// which the getters below check via REQUIRE. Confirm in the .cpp file.
40  CCrossValidationMulticlassStorage(bool compute_ROC=true, bool compute_PRC=false, bool compute_conf_matrices=false);
41 
44 
// Accessor for the ROC matrix stored for class c in fold `fold` of run `run`.
// Preconditions (ASSERTed): 0<=run<m_num_runs, 0<=fold<m_num_folds, 0<=c.
// Fails via REQUIRE with "ROC computation was not enabled" unless
// m_compute_ROC is set.
// NOTE(review): listing lines 59 and 61 are not shown, so neither an upper
// bound check on c nor the return statement is visible in this excerpt.
// Restore them from the original header before relying on this listing.
52  SGMatrix<float64_t> get_fold_ROC(int32_t run, int32_t fold, int32_t c)
53  {
54  ASSERT(0<=run);
55  ASSERT(run<m_num_runs);
56  ASSERT(0<=fold);
57  ASSERT(fold<m_num_folds);
58  ASSERT(0<=c);
60  REQUIRE(m_compute_ROC, "ROC computation was not enabled\n");
62  }
63 
// Accessor for the PRC matrix stored for class c in fold `fold` of run `run`.
// Preconditions (ASSERTed): 0<=run<m_num_runs, 0<=fold<m_num_folds, 0<=c.
// Fails via REQUIRE with "PRC computation was not enabled" unless
// m_compute_PRC is set.
// NOTE(review): listing lines 78 and 80 are not shown, so neither an upper
// bound check on c nor the return statement is visible in this excerpt.
// Restore them from the original header before relying on this listing.
71  SGMatrix<float64_t> get_fold_PRC(int32_t run, int32_t fold, int32_t c)
72  {
73  ASSERT(0<=run);
74  ASSERT(run<m_num_runs);
75  ASSERT(0<=fold);
76  ASSERT(fold<m_num_folds);
77  ASSERT(0<=c);
79  REQUIRE(m_compute_PRC, "PRC computation was not enabled\n");
81  }
82 
// NOTE(review): the signature of this method (listing line 87) is not shown.
// The body appends `evaluation` to m_binary_evaluations. The number of
// entries in that container is the stride get_fold_evaluation_result uses
// for its evaluation index e.
88  {
89  m_binary_evaluations->push_back(evaluation);
90  }
91 
// NOTE(review): both the signature (listing line 96) and the body line
// (listing line 98) of this method are missing from the listing. Its purpose
// cannot be determined from this excerpt; consult the original header.
97  {
99  }
100 
// Returns one stored binary-evaluation result for (run, fold, class c,
// evaluation e).
// Results live in the flat array m_evaluations_results. The index is ordered
// run, then fold, then class, then evaluation, where n_evals is the number of
// entries in m_binary_evaluations.
// Preconditions (ASSERTed): 0<=run<m_num_runs, 0<=fold<m_num_folds, 0<=c,
// 0<=e<n_evals.
// NOTE(review): listing line 115 is not shown. Presumably it bounds c by
// m_num_classes; the index formula only avoids aliasing neighbouring folds if
// c<m_num_classes. Confirm against the original header.
108  float64_t get_fold_evaluation_result(int32_t run, int32_t fold, int32_t c, int32_t e)
109  {
110  ASSERT(0<=run);
111  ASSERT(run<m_num_runs);
112  ASSERT(0<=fold);
113  ASSERT(fold<m_num_folds);
114  ASSERT(0<=c);
116  ASSERT(0<=e);
117  int32_t n_evals = m_binary_evaluations->get_num_elements();
118  ASSERT(e<n_evals);
119  return m_evaluations_results[run*m_num_folds*m_num_classes*n_evals+fold*m_num_classes*n_evals+c*n_evals+e];
120  }
121 
// Returns the accuracy stored for fold `fold` of run `run`.
// It is read from m_accuracies at index run*m_num_folds+fold, so each run
// occupies m_num_folds consecutive entries.
// Preconditions (ASSERTed): 0<=run<m_num_runs and 0<=fold<m_num_folds.
126  float64_t get_fold_accuracy(int32_t run, int32_t fold)
127  {
128  ASSERT(0<=run);
129  ASSERT(run<m_num_runs);
130  ASSERT(0<=fold);
131  ASSERT(fold<m_num_folds);
132  return m_accuracies[run*m_num_folds+fold];
133  }
134 
// Returns the confusion matrix stored for fold `fold` of run `run`.
// It is read from m_conf_matrices at index run*m_num_folds+fold, the same
// layout get_fold_accuracy uses.
// Preconditions (ASSERTed): 0<=run<m_num_runs and 0<=fold<m_num_folds.
// Fails via REQUIRE unless m_compute_conf_matrices is set.
139  SGMatrix<int32_t> get_fold_conf_matrix(int32_t run, int32_t fold)
140  {
141  ASSERT(0<=run);
142  ASSERT(run<m_num_runs);
143  ASSERT(0<=fold);
144  ASSERT(fold<m_num_folds);
145  REQUIRE(m_compute_conf_matrices, "Confusion matrices computation was not enabled\n");
146  return m_conf_matrices[run*m_num_folds+fold];
147  }
148 
// Virtual hooks declared by this class; only get_name is defined inline here.
// NOTE(review): the doc comments for these declarations (listing gaps at
// 149, 152, 155-157, 160-164, 168-172, 176) are not shown. Consult the
// original header or .cpp file for their contracts.
150  virtual void post_init();
151 
153  virtual void post_update_results();
154 
158  virtual void init_expose_labels(CLabels* labels);
159 
165  virtual void update_test_result(CLabels* results,
166  const char* prefix="");
167 
173  virtual void update_test_true_result(CLabels* results,
174  const char* prefix="");
175 
// Returns the fixed class name string "CrossValidationMulticlassStorage".
177  virtual const char* get_name() const { return "CrossValidationMulticlassStorage"; }
178 
179 protected:
// NOTE(review): listing lines 180-217 survive only as blank line numbers.
// The members used by the methods above are not declared in the visible lines:
// m_num_runs, m_num_folds, m_compute_ROC, m_compute_PRC,
// m_compute_conf_matrices, m_binary_evaluations, m_evaluations_results,
// m_accuracies and m_conf_matrices. Presumably they are in the missing lines
// or come from a base class. Confirm against the original header.
180 
183 
186 
189 
192 
195 
198 
201 
204 
207 
210 
213 
216 
// Number of classes; used as the class stride when indexing
// m_evaluations_results in get_fold_evaluation_result.
218  int32_t m_num_classes;
219 
220 };
221 
222 }
223 
224 #endif /* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */

SHOGUN Machine Learning Toolbox - Documentation