SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8  */
9 
14 
15 using namespace shogun;
16 
17 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
19 {
20  m_initialized = false;
21  m_compute_ROC = compute_ROC;
22  m_compute_PRC = compute_PRC;
23  m_compute_conf_matrices = compute_conf_matrices;
24  m_pred_labels = NULL;
25  m_true_labels = NULL;
26  m_num_classes = 0;
28 }
29 
30 
32 {
33  if (m_compute_ROC)
34  {
35  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
37 
39  }
40 
41  if (m_compute_PRC)
42  {
43  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
45 
47  }
48 
50  {
51  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
53 
55  }
56 
58 };
59 
60 
62 {
63  if (m_initialized)
64  SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n");
65 
66  if (m_compute_ROC)
67  {
68  SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes);
70  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
72  }
73 
74  if (m_compute_PRC)
75  {
76  SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes);
78  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
80  }
81 
84 
86 
88  {
90  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
92  }
93 
94  m_initialized = true;
95 }
96 
98 {
99  ASSERT((CMulticlassLabels*)labels);
100  m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
101 }
102 
104 {
105  CROCEvaluation eval_ROC;
106  CPRCEvaluation eval_PRC;
107  int32_t n_evals = m_binary_evaluations->get_num_elements();
108  for (int32_t c=0; c<m_num_classes; c++)
109  {
110  SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
111  CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
112  CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
113  if (m_compute_ROC)
114  {
115  eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
117  eval_ROC.get_ROC();
118  }
119  if (m_compute_PRC)
120  {
121  eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
123  eval_PRC.get_PRC();
124  }
125 
126  for (int32_t i=0; i<n_evals; i++)
127  {
129  m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
130  evaluator->evaluate(pred_labels_binary, true_labels_binary);
131  SG_UNREF(evaluator);
132  }
133 
134  SG_UNREF(pred_labels_binary);
135  SG_UNREF(true_labels_binary);
136  }
137  CMulticlassAccuracy accuracy;
138 
140 
142  {
144  }
145 }
146 
148 {
149  m_pred_labels = (CMulticlassLabels*)results;
150 }
151 
153 {
154  m_true_labels = (CMulticlassLabels*)results;
155 }
156 

SHOGUN Machine Learning Toolbox - Documentation