SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
evaluation
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
 */
9
10
#include <
shogun/evaluation/CrossValidationMulticlassStorage.h
>
11
#include <
shogun/evaluation/ROCEvaluation.h
>
12
#include <
shogun/evaluation/PRCEvaluation.h
>
13
#include <
shogun/evaluation/MulticlassAccuracy.h
>
14
15
using namespace
shogun;
16
17
/** Creates a multiclass cross-validation storage without allocating any
 * per-fold result buffers.
 *
 * The ROC/PRC graph arrays and the confusion-matrix array are allocated
 * later by post_init(). Until then every buffer pointer is NULL, so no code
 * path ever sees an indeterminate pointer.
 *
 * @param compute_ROC whether a ROC graph is stored per run/fold/class
 * @param compute_PRC whether a PRC graph is stored per run/fold/class
 * @param compute_conf_matrices whether a confusion matrix is stored per
 * run/fold
 */
CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(
		bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
	CCrossValidationOutput()
{
	m_initialized = false;
	m_compute_ROC = compute_ROC;
	m_compute_PRC = compute_PRC;
	m_compute_conf_matrices = compute_conf_matrices;
	m_pred_labels = NULL;
	m_true_labels = NULL;
	m_num_classes = 0;
	m_binary_evaluations = new CDynamicObjectArray();

	/* all three result buffers start out NULL; the PRC buffer must be
	 * initialized too, otherwise it holds an indeterminate value until
	 * post_init() runs */
	m_fold_ROC_graphs = NULL;
	m_fold_PRC_graphs = NULL;
	m_conf_matrices = NULL;
}
32
33
34
/** Releases the per-fold result buffers and the list of binary evaluations.
 *
 * The ROC/PRC/confusion-matrix buffers are SG_MALLOC'd arrays whose elements
 * were constructed with placement new in post_init(). Each element is
 * therefore destroyed explicitly before its array is freed. If post_init()
 * never ran, no buffer was allocated, and only the evaluation list is
 * released.
 */
CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
{
	if (m_initialized)
	{
		/* assumes the run/fold/class counts are unchanged since post_init(),
		 * so these match the sizes used for allocation */
		const int32_t num_graphs=m_num_folds*m_num_runs*m_num_classes;
		const int32_t num_fold_runs=m_num_folds*m_num_runs;

		if (m_compute_ROC)
		{
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_ROC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_ROC_graphs);
		}

		if (m_compute_PRC)
		{
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_PRC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_PRC_graphs);
		}

		if (m_compute_conf_matrices)
		{
			for (int32_t i=0; i<num_fold_runs; i++)
				m_conf_matrices[i].~SGMatrix<int32_t>();

			SG_FREE(m_conf_matrices);
		}
	}

	SG_UNREF(m_binary_evaluations);
}
62
63
64
void
CCrossValidationMulticlassStorage::post_init
()
65
{
66
if
(
m_initialized
)
67
SG_ERROR
(
"CrossValidationMulticlassStorage was already initialized once\n"
)
68
69
if
(
m_compute_ROC
)
70
{
71
SG_DEBUG
(
"Allocating %d ROC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
)
72
m_fold_ROC_graphs
= SG_MALLOC(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
73
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
74
new
(&
m_fold_ROC_graphs
[i])
SGMatrix<float64_t>
();
75
}
76
77
if
(
m_compute_PRC
)
78
{
79
SG_DEBUG
(
"Allocating %d PRC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
)
80
m_fold_PRC_graphs
= SG_MALLOC(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
81
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
82
new
(&
m_fold_PRC_graphs
[i])
SGMatrix<float64_t>
();
83
}
84
85
if
(
m_binary_evaluations
->
get_num_elements
())
86
m_evaluations_results
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
*
m_num_classes
*
m_binary_evaluations
->
get_num_elements
());
87
88
m_accuracies
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
);
89
90
if
(
m_compute_conf_matrices
)
91
{
92
m_conf_matrices
= SG_MALLOC(
SGMatrix<int32_t>
,
m_num_folds
*
m_num_runs
);
93
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
; i++)
94
new
(&
m_conf_matrices
[i])
SGMatrix<int32_t>
();
95
}
96
97
m_initialized
=
true
;
98
}
99
100
/** Reads the number of classes from the labels that will be cross-validated.
 *
 * @param labels the labels exposed by cross-validation; must be non-NULL
 * and of multiclass label type
 */
void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
{
	/* a pointer cast never fails, so asserting on its result only checks
	 * for NULL; verify the dynamic label type before downcasting */
	ASSERT(labels)
	ASSERT(labels->get_label_type()==LT_MULTICLASS)
	m_num_classes = static_cast<CMulticlassLabels*>(labels)->get_num_classes();
}
105
106
void
CCrossValidationMulticlassStorage::post_update_results
()
107
{
108
CROCEvaluation
eval_ROC;
109
CPRCEvaluation
eval_PRC;
110
int32_t n_evals =
m_binary_evaluations
->
get_num_elements
();
111
for
(int32_t c=0; c<
m_num_classes
; c++)
112
{
113
SG_DEBUG
(
"Computing ROC for run %d fold %d class %d"
,
m_current_run_index
,
m_current_fold_index
, c)
114
CBinaryLabels
* pred_labels_binary =
m_pred_labels
->
get_binary_for_class
(c);
115
CBinaryLabels
* true_labels_binary =
m_true_labels
->
get_binary_for_class
(c);
116
if
(
m_compute_ROC
)
117
{
118
eval_ROC.
evaluate
(pred_labels_binary, true_labels_binary);
119
m_fold_ROC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
120
eval_ROC.
get_ROC
();
121
}
122
if
(
m_compute_PRC
)
123
{
124
eval_PRC.
evaluate
(pred_labels_binary, true_labels_binary);
125
m_fold_PRC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
126
eval_PRC.
get_PRC
();
127
}
128
129
for
(int32_t i=0; i<n_evals; i++)
130
{
131
CBinaryClassEvaluation
* evaluator = (
CBinaryClassEvaluation
*)
m_binary_evaluations
->
get_element_safe
(i);
132
m_evaluations_results
[
m_current_run_index
*
m_num_folds
*m_num_classes*n_evals+
m_current_fold_index
*m_num_classes*n_evals+c*n_evals+i] =
133
evaluator->
evaluate
(pred_labels_binary, true_labels_binary);
134
SG_UNREF
(evaluator);
135
}
136
137
SG_UNREF
(pred_labels_binary);
138
SG_UNREF
(true_labels_binary);
139
}
140
CMulticlassAccuracy
accuracy;
141
142
m_accuracies
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] = accuracy.
evaluate
(
m_pred_labels
,
m_true_labels
);
143
144
if
(
m_compute_conf_matrices
)
145
{
146
m_conf_matrices
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] =
CMulticlassAccuracy::get_confusion_matrix
(
m_pred_labels
,
m_true_labels
);
147
}
148
}
149
150
/** Stores the predicted labels of the current fold for post_update_results().
 *
 * Only the pointer is kept and no reference is taken. The labels are
 * downcast to CMulticlassLabels* without a type check.
 *
 * @param results predicted labels of the current fold
 * @param prefix not used by this storage
 */
void CCrossValidationMulticlassStorage::update_test_result(CLabels* results,
		const char* prefix)
{
	CMulticlassLabels* predicted=static_cast<CMulticlassLabels*>(results);
	m_pred_labels=predicted;
}
154
155
/** Stores the ground-truth labels of the current fold for
 * post_update_results().
 *
 * Only the pointer is kept and no reference is taken. The labels are
 * downcast to CMulticlassLabels* without a type check.
 *
 * @param results true labels of the current fold
 * @param prefix not used by this storage
 */
void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results,
		const char* prefix)
{
	CMulticlassLabels* ground_truth=static_cast<CMulticlassLabels*>(results);
	m_true_labels=ground_truth;
}
159
SHOGUN
Machine Learning Toolbox - Documentation