SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8  */
9 
14 
15 using namespace shogun;
16 
/** Construct an empty multiclass cross-validation storage.
 *
 * Only records the requested options and resets pointers/counters;
 * nothing is allocated here. m_initialized starts false, and the
 * ROC/PRC graph buffers (one SGMatrix per run x fold x class) are
 * SG_MALLOC'd later by this file's one-time initialization routine,
 * which raises an error if m_initialized is already set.
 *
 * @param compute_ROC whether ROC graphs are computed and stored (m_compute_ROC)
 * @param compute_PRC whether PRC graphs are computed and stored (m_compute_PRC)
 * @param compute_conf_matrices stored in m_compute_conf_matrices
 *        (NOTE(review): its consumers are not visible in this listing)
 */
17 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
19 {
20  m_initialized = false;
21  m_compute_ROC = compute_ROC;
22  m_compute_PRC = compute_PRC;
23  m_compute_conf_matrices = compute_conf_matrices;
24  m_pred_labels = NULL;
25  m_true_labels = NULL;
26  m_num_classes = 0;
28 
29  m_fold_ROC_graphs=NULL;
    // Must be NULL like the ROC/confusion-matrix buffers: the destructor
    // calls SG_FREE(m_fold_PRC_graphs) whenever m_compute_PRC is set, and
    // before this line the pointer was left indeterminate if the storage
    // was destroyed without ever being initialized.
    m_fold_PRC_graphs=NULL;
30  m_conf_matrices=NULL;
31 }
32 
33 
35 {
36  if (m_compute_ROC)
37  {
38  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
40 
41  SG_FREE(m_fold_ROC_graphs);
42  }
43 
44  if (m_compute_PRC)
45  {
46  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
48 
49  SG_FREE(m_fold_PRC_graphs);
50  }
51 
53  {
54  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
56 
57  SG_FREE(m_conf_matrices);
58  }
59 
61 };
62 
63 
65 {
66  if (m_initialized)
67  SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n")
68 
69  if (m_compute_ROC)
70  {
71  SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes)
72  m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
73  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
75  }
76 
77  if (m_compute_PRC)
78  {
79  SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes)
80  m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
81  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
83  }
84 
87 
89 
91  {
93  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
95  }
96 
97  m_initialized = true;
98 }
99 
101 {
102  ASSERT((CMulticlassLabels*)labels)
103  m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
104 }
105 
107 {
108  CROCEvaluation eval_ROC;
109  CPRCEvaluation eval_PRC;
110  int32_t n_evals = m_binary_evaluations->get_num_elements();
111  for (int32_t c=0; c<m_num_classes; c++)
112  {
113  SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c)
114  CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
115  CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
116  if (m_compute_ROC)
117  {
118  eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
120  eval_ROC.get_ROC();
121  }
122  if (m_compute_PRC)
123  {
124  eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
126  eval_PRC.get_PRC();
127  }
128 
129  for (int32_t i=0; i<n_evals; i++)
130  {
132  m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
133  evaluator->evaluate(pred_labels_binary, true_labels_binary);
134  SG_UNREF(evaluator);
135  }
136 
137  SG_UNREF(pred_labels_binary);
138  SG_UNREF(true_labels_binary);
139  }
140  CMulticlassAccuracy accuracy;
141 
143 
145  {
147  }
148 }
149 
151 {
152  m_pred_labels = (CMulticlassLabels*)results;
153 }
154 
156 {
157  m_true_labels = (CMulticlassLabels*)results;
158 }
159 

SHOGUN Machine Learning Toolbox - Documentation