SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
labels
MulticlassLabels.cpp
Go to the documentation of this file.
1
#include <
shogun/labels/DenseLabels.h
>
2
#include <
shogun/labels/BinaryLabels.h
>
3
#include <
shogun/labels/MulticlassLabels.h
>
4
#include <
shogun/base/ParameterMap.h
>
5
6
using namespace
shogun;
7
8
/** Creates an empty set of multiclass labels (no labels, no confidences). */
CMulticlassLabels::CMulticlassLabels()
	: CDenseLabels()
{
	this->init();
}
12
13
/** Creates multiclass labels, forwarding num_labels to the CDenseLabels
 * constructor, and resets the confidence matrix to empty.
 *
 * @param num_labels number of labels handed to CDenseLabels
 */
CMulticlassLabels::CMulticlassLabels(int32_t num_labels)
	: CDenseLabels(num_labels)
{
	this->init();
}
17
18
/** Creates multiclass labels from a vector of label values.
 *
 * The confidence matrix is reset first, then the given values are stored
 * via set_labels().
 *
 * @param src label values
 */
CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src)
	: CDenseLabels()
{
	this->init();
	this->set_labels(src);
}
23
24
/** Creates multiclass labels, passing loader on to the CDenseLabels
 * constructor (presumably to read labels from the file; see CDenseLabels),
 * and resets the confidence matrix to empty.
 *
 * @param loader file object handed to CDenseLabels
 */
CMulticlassLabels::CMulticlassLabels(CFile* loader)
	: CDenseLabels(loader)
{
	this->init();
}
28
29
/** Destructor. Performs no work of its own; members and the base class
 * clean up after themselves. */
CMulticlassLabels::~CMulticlassLabels()
{
	/* intentionally empty */
}
32
33
void
CMulticlassLabels::init()
34
{
35
/* for this to work, migration has to be fixed */
36
// SG_ADD(&m_multiclass_confidences, "multiclass_confidences", "Vectors of "
37
// "multiclass confidences", MS_NOT_AVAILABLE);
38
39
// m_parameter_map->finalize_map();
40
41
m_multiclass_confidences
=
SGMatrix<float64_t>
();
42
}
43
44
void
CMulticlassLabels::set_multiclass_confidences
(int32_t i,
45
SGVector<float64_t>
confidences)
46
{
47
REQUIRE
(confidences.
size
()==
m_multiclass_confidences
.
num_rows
,
48
"%s::set_multiclass_confidences(): Length of confidences should "
49
"match size of the matrix"
,
get_name
());
50
51
for
(
index_t
j=0; j<confidences.
size
(); j++)
52
m_multiclass_confidences
(j,i) = confidences[j];
53
}
54
55
/** Returns a copy of column i of the confidence matrix.
 *
 * @param i column (label) index to read; must lie in
 *        [0, m_multiclass_confidences.num_cols) whenever the matrix has rows
 * @return vector of length m_multiclass_confidences.num_rows (empty when no
 *         confidences are stored)
 */
SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i)
{
	/* Column i is read below. Reject indices outside the allocated matrix
	 * instead of reading out of bounds. With zero rows nothing is read and an
	 * empty vector is returned, exactly as before. */
	REQUIRE(m_multiclass_confidences.num_rows==0 ||
			(i>=0 && i<m_multiclass_confidences.num_cols),
			"%s::get_multiclass_confidences(): Index %d is out of range "
			"[0, %d)", get_name(), i, m_multiclass_confidences.num_cols);

	SGVector<float64_t> confs(m_multiclass_confidences.num_rows);
	for (index_t j=0; j<confs.size(); j++)
		confs[j]=m_multiclass_confidences(j,i);

	return confs;
}
63
64
/** Replaces the confidence matrix with a freshly allocated
 * n_classes x (number of stored labels) matrix.
 *
 * Raises an error (REQUIRE) when no labels are stored.
 *
 * @param n_classes number of rows, one per class
 */
void CMulticlassLabels::allocate_confidences_for(int32_t n_classes)
{
	const int32_t num_stored=m_labels.size();
	REQUIRE(num_stored!=0, "%s::allocate_confidences_for(): There should be "
			"labels to store confidences", get_name());

	/* one column per stored label (m_labels.size(), not the subset size) */
	SGMatrix<float64_t> fresh(n_classes, num_stored);
	m_multiclass_confidences=fresh;
}
72
73
void
CMulticlassLabels::ensure_valid
(
const
char
* context)
74
{
75
CDenseLabels::ensure_valid
(context);
76
77
int32_t subset_size=
get_num_labels
();
78
for
(int32_t i=0; i<subset_size; i++)
79
{
80
int32_t real_i =
m_subset_stack
->
subset_idx_conversion
(i);
81
int32_t label = int32_t(
m_labels
[real_i]);
82
83
if
(label<0 ||
float64_t
(label)!=
m_labels
[real_i])
84
{
85
SG_ERROR
(
"%s%sMulticlass Labels must be in range 0...<nr_classes-1> and integers!\n"
,
86
context?context:
""
, context?
": "
:
""
);
87
}
88
}
89
}
90
91
/** @return LT_MULTICLASS, the label type tag of this class */
ELabelType CMulticlassLabels::get_label_type() const
{
	const ELabelType type=LT_MULTICLASS;
	return type;
}
95
96
/** Builds one-vs-rest binary labels for class i.
 *
 * For each label k the sign is positive iff get_int_label(k) == i. The
 * magnitude is 1.0, unless the confidence matrix is non-empty in both
 * dimensions; then it is the stored confidence of the label's own class,
 * m_multiclass_confidences(get_int_label(k), k).
 *
 * @param i class treated as the positive class
 * @return newly allocated CBinaryLabels (ownership presumably passes to the
 *         caller per shogun's ref-counting conventions -- confirm)
 */
CBinaryLabels* CMulticlassLabels::get_binary_for_class(int32_t i)
{
	SGVector<float64_t> binary_labels(get_num_labels());

	const bool have_confidences=m_multiclass_confidences.num_rows!=0 &&
			m_multiclass_confidences.num_cols!=0;

	for (int32_t k=0; k<binary_labels.vlen; k++)
	{
		const int32_t predicted=get_int_label(k);
		const float64_t magnitude=have_confidences ?
				m_multiclass_confidences(predicted,k) : 1.0;
		binary_labels[k]=(predicted==i) ? magnitude : -magnitude;
	}

	return new CBinaryLabels(binary_labels);
}
124
125
/** Returns the distinct label values visible through the current subset,
 * in the order produced by SGVector<float64_t>::unique.
 *
 * @return newly allocated vector holding each distinct value once
 */
SGVector<float64_t> CMulticlassLabels::get_unique_labels()
{
	/* work on a copy: unique() rearranges its input in place, and the
	 * labels may be viewed through a subset */
	SGVector<float64_t> scratch=get_labels_copy();
	const int32_t n_unique=SGVector<float64_t>::unique(scratch.vector,
			scratch.vlen);

	/* unique() leaves the distinct values in the first n_unique slots */
	SGVector<float64_t> result(n_unique);
	for (int32_t k=0; k<n_unique; ++k)
		result[k]=scratch[k];

	return result;
}
137
138
139
/** @return number of distinct label values, as counted by get_unique_labels() */
int32_t CMulticlassLabels::get_num_classes()
{
	return get_unique_labels().vlen;
}
SHOGUN
Machine Learning Toolbox - Documentation