SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8  */
9 
14 
15 using namespace shogun;
16 
// Constructor: records which per-fold outputs (ROC graphs, PRC graphs,
// confusion matrices) are requested and resets the remaining state
// (no labels stored, zero classes, not yet initialized).
// NOTE(review): this Doxygen listing omits original line 18 (the initializer
// list after ':') and line 27, so the code below is incomplete as shown.
// NOTE(review): only m_fold_ROC_graphs and m_conf_matrices are visibly reset
// to NULL; m_fold_PRC_graphs is not set in the visible lines. Unless the
// missing line 27 does it, that pointer stays indeterminate until the
// initialization routine allocates it. Confirm against the full source.
17 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
19 {
20  m_initialized = false;
21  m_compute_ROC = compute_ROC;
22  m_compute_PRC = compute_PRC;
23  m_compute_conf_matrices = compute_conf_matrices;
24  m_pred_labels = NULL;
25  m_true_labels = NULL;
26  m_num_classes = 0;
28 
29  m_fold_ROC_graphs=NULL;
30  m_conf_matrices=NULL;
31 }
32 
33 
// Presumably the class destructor. Its signature (original line 34) is
// missing from this listing, as are the guard conditions on lines 36, 41, 46
// and 51 and the statement on line 53.
// It releases the ROC-graph, PRC-graph and confusion-matrix arrays with
// SG_FREE. The ROC/PRC arrays are allocated elsewhere in this file with
// SG_MALLOC(SGMatrix<float64_t>, ...).
// NOTE(review): whether SG_FREE also destroys the SGMatrix elements depends
// on Shogun's allocator macros; confirm no element references leak.
// NOTE(review): the ';' after the closing brace is a stray empty declaration
// and is harmless.
35 {
37  {
38  SG_FREE(m_fold_ROC_graphs);
39  }
40 
42  {
43  SG_FREE(m_fold_PRC_graphs);
44  }
45 
47  {
48  SG_FREE(m_conf_matrices);
49  }
50 
52  {
54  }
55 };
56 
57 
// One-time initialization (signature, original line 58, missing from this
// listing). A second call raises SG_ERROR. When ROC and/or PRC output is
// enabled, it allocates one SGMatrix<float64_t> slot per (fold, run, class),
// i.e. m_num_folds*m_num_runs*m_num_classes entries, then marks the storage
// as initialized.
// NOTE(review): the loop bodies (lines 68, 76) and lines 79-88 are missing
// here. Line 87 suggests a further per-(fold, run) allocation, presumably
// for the confusion matrices; confirm.
// NOTE(review): m_num_classes is 0 after construction, so it must be set
// (see the routine that reads get_num_classes()) before this runs, or zero
// slots are allocated.
// NOTE(review): the slot count is an unchecked int32 product and could
// overflow for extremely large fold/run/class counts.
59 {
60  if (m_initialized)
61  SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n")
62 
63  if (m_compute_ROC)
64  {
65  SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes)
66  m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
67  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
69  }
70 
71  if (m_compute_PRC)
72  {
73  SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes)
74  m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
75  for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
77  }
78 
81 
83 
85  {
87  for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
89  }
90 
91  m_initialized = true;
92 }
93 
// Records the number of classes from the given labels object (signature,
// original line 94, missing from this listing).
// NOTE(review): the ASSERT only checks that the pointer is non-null. The
// C-style cast to CMulticlassLabels* performs no type check, so labels of
// another kind are not rejected here. A checked cast (e.g. dynamic_cast or
// a label-type query) would be needed for that.
95 {
96  ASSERT((CMulticlassLabels*)labels)
97  m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
98 }
99 
// Processes the current fold's stored predicted/true multiclass labels
// (signature, original line 100, missing from this listing).
// For each class c it builds one-vs-rest binary labels from both label sets.
// It then runs the ROC and/or PRC evaluation when enabled, and evaluates every
// element of m_binary_evaluations. Each result is written to
// m_evaluations_results at the flattened index
// [run][fold][class][evaluation], i.e.
// run*folds*classes*n_evals + fold*classes*n_evals + c*n_evals + i.
// The per-class binary labels are released with SG_UNREF afterwards.
// NOTE(review): lines 113 and 119 (presumably storing the ROC/PRC graphs),
// 125 (presumably fetching `evaluator`), and 136-140 (presumably using
// `accuracy`) are missing from this listing.
// NOTE(review): the SG_DEBUG message below lacks the trailing "\n" used by
// the other messages in this file. It also says "ROC" even when only PRC or
// the binary evaluations are computed.
101 {
102  CROCEvaluation eval_ROC;
103  CPRCEvaluation eval_PRC;
104  int32_t n_evals = m_binary_evaluations->get_num_elements();
105  for (int32_t c=0; c<m_num_classes; c++)
106  {
107  SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c)
108  CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
109  CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
110  if (m_compute_ROC)
111  {
112  eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
114  eval_ROC.get_ROC();
115  }
116  if (m_compute_PRC)
117  {
118  eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
120  eval_PRC.get_PRC();
121  }
122 
123  for (int32_t i=0; i<n_evals; i++)
124  {
126  m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
127  evaluator->evaluate(pred_labels_binary, true_labels_binary);
128  SG_UNREF(evaluator);
129  }
130 
131  SG_UNREF(pred_labels_binary);
132  SG_UNREF(true_labels_binary);
133  }
134  CMulticlassAccuracy accuracy;
135 
137 
139  {
141  }
142 }
143 
// Stores the predicted labels for the current fold (signature, original
// line 144, missing from this listing).
// NOTE(review): only a raw pointer is kept. No SG_REF is taken in this body,
// so presumably the caller must keep `results` alive while it is used.
// The C-style cast does not verify that `results` really is a
// CMulticlassLabels.
145 {
146  m_pred_labels = (CMulticlassLabels*)results;
147 }
148 
// Stores the ground-truth labels for the current fold (signature, original
// line 149, missing from this listing).
// NOTE(review): only a raw pointer is kept. No SG_REF is taken in this body,
// so presumably the caller must keep `results` alive while it is used.
// The C-style cast does not verify that `results` really is a
// CMulticlassLabels.
150 {
151  m_true_labels = (CMulticlassLabels*)results;
152 }
153 

SHOGUN Machine Learning Toolbox - Documentation