SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CrossValidationMulticlassStorage.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn
8  *
9  */
10 
11 #ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_
12 #define CROSSVALIDATIONMULTICLASSSTORAGE_H_
13 
14 #include <shogun/lib/config.h>
15 
19 #include <shogun/lib/SGMatrix.h>
21 
22 namespace shogun
23 {
24 
25 class CMachine;
26 class CLabels;
27 class CEvaluation;
28 
34 {
35 public:
36 
42  CCrossValidationMulticlassStorage(bool compute_ROC=true, bool compute_PRC=false, bool compute_conf_matrices=false);
43 
46 
54  SGMatrix<float64_t> get_fold_ROC(int32_t run, int32_t fold, int32_t c)
55  {
56  ASSERT(0<=run)
57  ASSERT(run<m_num_runs)
58  ASSERT(0<=fold)
59  ASSERT(fold<m_num_folds)
60  ASSERT(0<=c)
62  REQUIRE(m_compute_ROC, "ROC computation was not enabled\n")
64  }
65 
73  SGMatrix<float64_t> get_fold_PRC(int32_t run, int32_t fold, int32_t c)
74  {
75  ASSERT(0<=run)
76  ASSERT(run<m_num_runs)
77  ASSERT(0<=fold)
78  ASSERT(fold<m_num_folds)
79  ASSERT(0<=c)
81  REQUIRE(m_compute_PRC, "PRC computation was not enabled\n")
83  }
84 
90  {
91  m_binary_evaluations->push_back(evaluation);
92  }
93 
99  {
101  }
102 
110  float64_t get_fold_evaluation_result(int32_t run, int32_t fold, int32_t c, int32_t e)
111  {
112  ASSERT(0<=run)
113  ASSERT(run<m_num_runs)
114  ASSERT(0<=fold)
115  ASSERT(fold<m_num_folds)
116  ASSERT(0<=c)
118  ASSERT(0<=e)
119  int32_t n_evals = m_binary_evaluations->get_num_elements();
120  ASSERT(e<n_evals)
121  return m_evaluations_results[run*m_num_folds*m_num_classes*n_evals+fold*m_num_classes*n_evals+c*n_evals+e];
122  }
123 
128  float64_t get_fold_accuracy(int32_t run, int32_t fold)
129  {
130  ASSERT(0<=run)
131  ASSERT(run<m_num_runs)
132  ASSERT(0<=fold)
133  ASSERT(fold<m_num_folds)
134  return m_accuracies[run*m_num_folds+fold];
135  }
136 
141  SGMatrix<int32_t> get_fold_conf_matrix(int32_t run, int32_t fold)
142  {
143  ASSERT(0<=run)
144  ASSERT(run<m_num_runs)
145  ASSERT(0<=fold)
146  ASSERT(fold<m_num_folds)
147  REQUIRE(m_compute_conf_matrices, "Confusion matrices computation was not enabled\n")
148  return m_conf_matrices[run*m_num_folds+fold];
149  }
150 
152  virtual void post_init();
153 
155  virtual void post_update_results();
156 
160  virtual void init_expose_labels(CLabels* labels);
161 
167  virtual void update_test_result(CLabels* results,
168  const char* prefix="");
169 
175  virtual void update_test_true_result(CLabels* results,
176  const char* prefix="");
177 
179  virtual const char* get_name() const { return "CrossValidationMulticlassStorage"; }
180 
181 protected:
182 
185 
188 
191 
194 
197 
200 
203 
206 
209 
212 
215 
218 
220  int32_t m_num_classes;
221 
222 };
223 
224 }
225 
226 #endif /* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */
void append_binary_evaluation(CBinaryClassEvaluation *evaluation)
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
#define REQUIRE(x,...)
Definition: SGIO.h:206
float64_t get_fold_accuracy(int32_t run, int32_t fold)
Multiclass Labels for multi-class classification.
SGMatrix< int32_t > get_fold_conf_matrix(int32_t run, int32_t fold)
float64_t get_fold_evaluation_result(int32_t run, int32_t fold, int32_t c, int32_t e)
virtual void update_test_true_result(CLabels *results, const char *prefix="")
Class for storing multiclass evaluation information in every fold of cross-validation.
Class for managing individual folds in cross-validation.
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
CCrossValidationMulticlassStorage(bool compute_ROC=true, bool compute_PRC=false, bool compute_conf_matrices=false)
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an array.
CBinaryClassEvaluation * get_binary_evaluation(int32_t idx)
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
CSGObject * get_element_safe(int32_t index) const
SGMatrix< float64_t > get_fold_PRC(int32_t run, int32_t fold, int32_t c)
SGMatrix< float64_t > get_fold_ROC(int32_t run, int32_t fold, int32_t c)
The class TwoClassEvaluation, a base class used to evaluate binary classification labels...
virtual void update_test_result(CLabels *results, const char *prefix="")

SHOGUN Machine Learning Toolbox - Documentation