SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
evaluation
CrossValidationMulticlassStorage.h
Go to the documentation of this file.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn
 */

#ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_
12
#define CROSSVALIDATIONMULTICLASSSTORAGE_H_
13
14
#include <
shogun/evaluation/CrossValidationOutput.h
>
15
#include <
shogun/evaluation/BinaryClassEvaluation.h
>
16
#include <
shogun/labels/MulticlassLabels.h
>
17
#include <
shogun/lib/SGMatrix.h
>
18
#include <
shogun/lib/DynamicObjectArray.h
>
19
20
namespace
shogun
21
{
22
23
class
CMachine;
24
class
CLabels;
25
class
CEvaluation;
26
31
class
CCrossValidationMulticlassStorage
:
public
CCrossValidationOutput
32
{
33
public
:
34
40
CCrossValidationMulticlassStorage
(
bool
compute_ROC=
true
,
bool
compute_PRC=
false
,
bool
compute_conf_matrices=
false
);
41
43
virtual
~CCrossValidationMulticlassStorage
();
44
52
SGMatrix<float64_t>
get_fold_ROC
(int32_t run, int32_t fold, int32_t c)
53
{
54
ASSERT
(0<=run)
55
ASSERT
(run<
m_num_runs
)
56
ASSERT
(0<=fold)
57
ASSERT
(fold<
m_num_folds
)
58
ASSERT
(0<=c)
59
ASSERT
(c<
m_num_classes
)
60
REQUIRE
(
m_compute_ROC
,
"ROC computation was not enabled\n"
)
61
return
m_fold_ROC_graphs
[run*
m_num_folds
*
m_num_classes
+fold*
m_num_classes
+c];
62
}
63
71
SGMatrix<float64_t>
get_fold_PRC
(int32_t run, int32_t fold, int32_t c)
72
{
73
ASSERT
(0<=run)
74
ASSERT
(run<
m_num_runs
)
75
ASSERT
(0<=fold)
76
ASSERT
(fold<
m_num_folds
)
77
ASSERT
(0<=c)
78
ASSERT
(c<
m_num_classes
)
79
REQUIRE
(
m_compute_PRC
,
"PRC computation was not enabled\n"
)
80
return
m_fold_PRC_graphs
[run*
m_num_folds
*
m_num_classes
+fold*
m_num_classes
+c];
81
}
82
87
void
append_binary_evaluation
(
CBinaryClassEvaluation
* evaluation)
88
{
89
m_binary_evaluations
->
push_back
(evaluation);
90
}
91
96
CBinaryClassEvaluation
*
get_binary_evaluation
(int32_t idx)
97
{
98
return
(
CBinaryClassEvaluation
*)
m_binary_evaluations
->
get_element_safe
(idx);
99
}
100
108
float64_t
get_fold_evaluation_result
(int32_t run, int32_t fold, int32_t c, int32_t e)
109
{
110
ASSERT
(0<=run)
111
ASSERT
(run<
m_num_runs
)
112
ASSERT
(0<=fold)
113
ASSERT
(fold<
m_num_folds
)
114
ASSERT
(0<=c)
115
ASSERT
(c<
m_num_classes
)
116
ASSERT
(0<=e)
117
int32_t n_evals =
m_binary_evaluations
->
get_num_elements
();
118
ASSERT
(e<n_evals)
119
return
m_evaluations_results
[run*
m_num_folds
*
m_num_classes
*n_evals+fold*
m_num_classes
*n_evals+c*n_evals+e];
120
}
121
126
float64_t
get_fold_accuracy
(int32_t run, int32_t fold)
127
{
128
ASSERT
(0<=run)
129
ASSERT
(run<
m_num_runs
)
130
ASSERT
(0<=fold)
131
ASSERT
(fold<
m_num_folds
)
132
return
m_accuracies
[run*
m_num_folds
+fold];
133
}
134
139
SGMatrix<int32_t>
get_fold_conf_matrix
(int32_t run, int32_t fold)
140
{
141
ASSERT
(0<=run)
142
ASSERT
(run<
m_num_runs
)
143
ASSERT
(0<=fold)
144
ASSERT
(fold<
m_num_folds
)
145
REQUIRE
(
m_compute_conf_matrices
,
"Confusion matrices computation was not enabled\n"
)
146
return
m_conf_matrices
[run*
m_num_folds
+fold];
147
}
148
150
virtual
void
post_init
();
151
153
virtual
void
post_update_results
();
154
158
virtual
void
init_expose_labels
(
CLabels
* labels);
159
165
virtual
void
update_test_result
(
CLabels
* results,
166
const
char
* prefix=
""
);
167
173
virtual
void
update_test_true_result
(
CLabels
* results,
174
const
char
* prefix=
""
);
175
177
virtual
const
char
*
get_name
()
const
{
return
"CrossValidationMulticlassStorage"
; }
178
179
protected
:
180
182
bool
m_initialized
;
183
185
CDynamicObjectArray
*
m_binary_evaluations
;
186
188
SGVector<float64_t>
m_evaluations_results
;
189
191
SGVector<float64_t>
m_accuracies
;
192
194
bool
m_compute_ROC
;
195
197
SGMatrix<float64_t>
*
m_fold_ROC_graphs
;
198
200
bool
m_compute_PRC
;
201
203
SGMatrix<float64_t>
*
m_fold_PRC_graphs
;
204
206
bool
m_compute_conf_matrices
;
207
209
SGMatrix<int32_t>
*
m_conf_matrices
;
210
212
CMulticlassLabels
*
m_pred_labels
;
213
215
CMulticlassLabels
*
m_true_labels
;
216
218
int32_t
m_num_classes
;
219
220
};
221
222
}
223
224
#endif
/* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */
SHOGUN
Machine Learning Toolbox - Documentation