SHOGUN
v2.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
evaluation
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8
*/
9
10
#include <
shogun/evaluation/CrossValidationMulticlassStorage.h
>
11
#include <
shogun/evaluation/ROCEvaluation.h
>
12
#include <
shogun/evaluation/PRCEvaluation.h
>
13
#include <
shogun/evaluation/MulticlassAccuracy.h
>
14
15
using namespace
shogun;
16
17
/** Constructor.
 *
 * @param compute_ROC whether a ROC graph is stored for every
 *        (run, fold, class) triple
 * @param compute_PRC whether a PRC graph is stored for every
 *        (run, fold, class) triple
 * @param compute_conf_matrices whether a confusion matrix is stored for
 *        every (run, fold) pair
 *
 * The graph/matrix arrays are not allocated here; post_init() does that once
 * the number of runs, folds and classes is known. Their pointers are set to
 * NULL so the object is in a well-defined state even if post_init() is never
 * called.
 */
CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(
		bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
	CCrossValidationOutput()
{
	m_initialized = false;
	m_compute_ROC = compute_ROC;
	m_compute_PRC = compute_PRC;
	m_compute_conf_matrices = compute_conf_matrices;
	m_pred_labels = NULL;
	m_true_labels = NULL;
	m_num_classes = 0;

	/* allocated in post_init(); NULL until then so cleanup never sees
	 * indeterminate pointer values */
	m_fold_ROC_graphs = NULL;
	m_fold_PRC_graphs = NULL;
	m_conf_matrices = NULL;

	m_binary_evaluations = new CDynamicObjectArray();
}
29
30
31
/** Destructor.
 *
 * Destroys and frees the per-fold ROC/PRC graphs and confusion matrices that
 * post_init() allocated, then releases the array of binary evaluations.
 *
 * post_init() sets m_initialized only after all of its allocations, so when
 * it is false none of the arrays exist and they must not be touched
 * (destroying/freeing them would act on unallocated memory).
 */
CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
{
	if (m_initialized)
	{
		/* element counts as used for allocation in post_init()
		 * NOTE(review): assumes m_num_classes was not changed after
		 * post_init() */
		const int32_t num_fold_runs = m_num_folds*m_num_runs;
		const int32_t num_graphs = num_fold_runs*m_num_classes;

		/* the arrays were obtained with SG_MALLOC and filled via placement
		 * new, so every element's destructor is invoked explicitly before
		 * the raw memory is released */
		if (m_compute_ROC)
		{
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_ROC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_ROC_graphs);
		}

		if (m_compute_PRC)
		{
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_PRC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_PRC_graphs);
		}

		if (m_compute_conf_matrices)
		{
			for (int32_t i=0; i<num_fold_runs; i++)
				m_conf_matrices[i].~SGMatrix<int32_t>();

			SG_FREE(m_conf_matrices);
		}
	}

	SG_UNREF(m_binary_evaluations);
}
59
60
61
void
CCrossValidationMulticlassStorage::post_init
()
62
{
63
if
(
m_initialized
)
64
SG_ERROR
(
"CrossValidationMulticlassStorage was already initialized once\n"
);
65
66
if
(
m_compute_ROC
)
67
{
68
SG_DEBUG
(
"Allocating %d ROC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
);
69
m_fold_ROC_graphs
=
SG_MALLOC
(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
70
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
71
new
(&
m_fold_ROC_graphs
[i])
SGMatrix<float64_t>
();
72
}
73
74
if
(
m_compute_PRC
)
75
{
76
SG_DEBUG
(
"Allocating %d PRC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
);
77
m_fold_PRC_graphs
=
SG_MALLOC
(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
78
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
79
new
(&
m_fold_PRC_graphs
[i])
SGMatrix<float64_t>
();
80
}
81
82
if
(
m_binary_evaluations
->
get_num_elements
())
83
m_evaluations_results
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
*
m_num_classes
*
m_binary_evaluations
->
get_num_elements
());
84
85
m_accuracies
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
);
86
87
if
(
m_compute_conf_matrices
)
88
{
89
m_conf_matrices
=
SG_MALLOC
(
SGMatrix<int32_t>
,
m_num_folds
*
m_num_runs
);
90
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
; i++)
91
new
(&
m_conf_matrices
[i])
SGMatrix<int32_t>
();
92
}
93
94
m_initialized
=
true
;
95
}
96
97
/** Takes the number of classes from the labels of the whole data set.
 *
 * @param labels labels exposed by the cross-validation; must be non-NULL and
 *        must be CMulticlassLabels
 */
void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
{
	/* checked downcast: the ASSERT fails both for NULL and for labels of
	 * another kind (e.g. binary), which would otherwise be misused as
	 * multiclass labels */
	CMulticlassLabels* multiclass_labels =
		dynamic_cast<CMulticlassLabels*>(labels);
	ASSERT(multiclass_labels);

	m_num_classes = multiclass_labels->get_num_classes();
}
102
103
void
CCrossValidationMulticlassStorage::post_update_results
()
104
{
105
CROCEvaluation
eval_ROC;
106
CPRCEvaluation
eval_PRC;
107
int32_t n_evals =
m_binary_evaluations
->
get_num_elements
();
108
for
(int32_t c=0; c<
m_num_classes
; c++)
109
{
110
SG_DEBUG
(
"Computing ROC for run %d fold %d class %d"
,
m_current_run_index
,
m_current_fold_index
, c);
111
CBinaryLabels
* pred_labels_binary =
m_pred_labels
->
get_binary_for_class
(c);
112
CBinaryLabels
* true_labels_binary =
m_true_labels
->
get_binary_for_class
(c);
113
if
(
m_compute_ROC
)
114
{
115
eval_ROC.
evaluate
(pred_labels_binary, true_labels_binary);
116
m_fold_ROC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
117
eval_ROC.
get_ROC
();
118
}
119
if
(
m_compute_PRC
)
120
{
121
eval_PRC.
evaluate
(pred_labels_binary, true_labels_binary);
122
m_fold_PRC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
123
eval_PRC.
get_PRC
();
124
}
125
126
for
(int32_t i=0; i<n_evals; i++)
127
{
128
CBinaryClassEvaluation
* evaluator = (
CBinaryClassEvaluation
*)
m_binary_evaluations
->
get_element_safe
(i);
129
m_evaluations_results
[
m_current_run_index
*
m_num_folds
*m_num_classes*n_evals+
m_current_fold_index
*m_num_classes*n_evals+c*n_evals+i] =
130
evaluator->
evaluate
(pred_labels_binary, true_labels_binary);
131
SG_UNREF
(evaluator);
132
}
133
134
SG_UNREF
(pred_labels_binary);
135
SG_UNREF
(true_labels_binary);
136
}
137
CMulticlassAccuracy
accuracy;
138
139
m_accuracies
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] = accuracy.
evaluate
(
m_pred_labels
,
m_true_labels
);
140
141
if
(
m_compute_conf_matrices
)
142
{
143
m_conf_matrices
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] =
CMulticlassAccuracy::get_confusion_matrix
(
m_pred_labels
,
m_true_labels
);
144
}
145
}
146
147
/** Remembers the labels predicted on the current fold's test data.
 *
 * @param results predicted labels, treated as CMulticlassLabels (unchecked
 *        cast); only the pointer is stored, no reference is taken
 * @param prefix not used by this storage
 */
void CCrossValidationMulticlassStorage::update_test_result(CLabels* results,
		const char* prefix)
{
	m_pred_labels = static_cast<CMulticlassLabels*>(results);
}
151
152
/** Remembers the true labels of the current fold's test data.
 *
 * @param results true labels, treated as CMulticlassLabels (unchecked cast);
 *        only the pointer is stored, no reference is taken
 * @param prefix not used by this storage
 */
void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results,
		const char* prefix)
{
	m_true_labels = static_cast<CMulticlassLabels*>(results);
}
156
SHOGUN
Machine Learning Toolbox - Documentation