SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
multiclass
MulticlassOneVsRestStrategy.cpp
Go to the documentation of this file.
1
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Chiyuan Zhang
 * Copyright (C) 2012 Chiyuan Zhang
 */
10
11
#include <
shogun/multiclass/MulticlassOneVsRestStrategy.h
>
12
#include <
shogun/labels/BinaryLabels.h
>
13
#include <
shogun/labels/MulticlassLabels.h
>
14
#include <
shogun/mathematics/Math.h
>
15
16
using namespace
shogun;
17
18
/* Default constructor.
 *
 * Adds no state of its own; all initialisation is delegated to the
 * CMulticlassStrategy default constructor.
 */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
	: CMulticlassStrategy()
{
	// intentionally empty
}
22
23
/* Construct a one-vs-rest strategy with an explicit probability heuristic.
 *
 * prob_heuris is handed unchanged to the CMulticlassStrategy constructor.
 * Nothing else is initialised here.
 */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(EProbHeuristicType prob_heuris)
	: CMulticlassStrategy(prob_heuris)
{
	// intentionally empty
}
27
28
/* Prepare the binary training labels for the next one-vs-rest subproblem.
 *
 * Each example whose multiclass label equals the current class index
 * (m_train_iter) is labelled +1.0 in m_train_labels. Every other example is
 * labelled -1.0.
 *
 * The base-class train_prepare_next() is called only after the labels are
 * written, because (per the original comment) it advances m_train_iter.
 *
 * Returns an empty SGVector<int32_t>. Presumably an empty vector tells the
 * caller that all examples are used for this subproblem — TODO confirm
 * against CMulticlassStrategy's contract.
 *
 * NOTE(review): the casts are unchecked, exactly as before. They assume
 * m_orig_labels really is a CMulticlassLabels and m_train_labels a
 * CBinaryLabels with at least get_num_labels() entries.
 */
SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
{
	// These casts and the label count do not change inside the loop, so
	// compute them once instead of on every iteration.
	CMulticlassLabels* orig_labels = static_cast<CMulticlassLabels*>(m_orig_labels);
	CBinaryLabels* train_labels = static_cast<CBinaryLabels*>(m_train_labels);
	const int32_t num_labels = m_orig_labels->get_num_labels();

	for (int32_t i = 0; i < num_labels; ++i)
	{
		const float64_t target =
			(orig_labels->get_int_label(i) == m_train_iter) ? +1.0 : -1.0;
		train_labels->set_label(i, target);
	}

	// increase m_train_iter *after* setting labels
	CMulticlassStrategy::train_prepare_next();

	return SGVector<int32_t>();
}
43
44
/* Decide the predicted class from the per-class one-vs-rest outputs.
 *
 * If a rejection strategy is set and it rejects these outputs, returns
 * CDenseLabels::REJECTION_LABEL. Otherwise returns the index of the maximum
 * output, computed by SGVector<float64_t>::arg_max over the whole vector
 * with stride 1.
 */
int32_t CMulticlassOneVsRestStrategy::decide_label(SGVector<float64_t> outputs)
{
	if (m_rejection_strategy)
	{
		if (m_rejection_strategy->reject(outputs))
			return CDenseLabels::REJECTION_LABEL;
	}

	float64_t* scores = outputs.vector;
	return SGVector<float64_t>::arg_max(scores, 1, outputs.vlen);
}
51
52
/* Return the indices of the n_outputs highest-scoring classes.
 *
 * The values of outputs are copied into a private buffer before sorting, so
 * the caller's outputs is not reordered. The copy is sorted together with the
 * class indices by CMath::qsort_backward_index. The first n_outputs indices
 * are then returned. This assumes qsort_backward_index sorts in descending
 * order, as its name and this top-n selection imply — confirm in Math.h.
 *
 * n_outputs must lie in [0, outputs.vlen]; otherwise SG_ERROR is raised.
 * Previously an out-of-range value read past the end of the index buffer.
 */
SGVector<index_t> CMulticlassOneVsRestStrategy::decide_label_multiple_output(
	SGVector<float64_t> outputs, int32_t n_outputs)
{
	const index_t num_outputs = outputs.vlen;
	if (n_outputs < 0 || n_outputs > num_outputs)
	{
		SG_ERROR("%s::decide_label_multiple_output(): n_outputs = %d must be in [0, %d]\n",
			get_name(), n_outputs, num_outputs);
	}

	// Work buffers are SGVectors, which manage their own memory (as the
	// returned result already relies on). Nothing has to be freed by hand,
	// and the index type matches the returned SGVector<index_t>.
	SGVector<float64_t> sorted_outputs(num_outputs);
	SGVector<index_t> indices(num_outputs);
	for (index_t i = 0; i < num_outputs; i++)
	{
		sorted_outputs[i] = outputs[i];
		indices[i] = i;
	}
	CMath::qsort_backward_index(sorted_outputs.vector, indices.vector, num_outputs);

	SGVector<index_t> result(n_outputs);
	for (int32_t i = 0; i < n_outputs; i++)
		result[i] = indices[i];
	return result;
}
69
70
/* Rescale outputs in place according to the configured probability heuristic.
 *
 * This overload has no sigmoid parameters. It behaves as follows:
 *  - PROB_HEURIS_NONE : outputs are left unchanged
 *  - OVA_NORM         : delegates to rescale_heuris_norm()
 *  - OVA_SOFTMAX      : SG_ERROR, because sigmoid parameters are required
 *                       (use the As/Bs overload)
 *  - anything else    : SG_ERROR, unknown one-vs-rest heuristic
 */
void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs)
{
	const EProbHeuristicType heuris = get_prob_heuris_type();

	if (heuris == PROB_HEURIS_NONE)
		return;

	if (heuris == OVA_NORM)
	{
		rescale_heuris_norm(outputs);
		return;
	}

	// Braces kept around SG_ERROR: the macro may expand to a block.
	if (heuris == OVA_SOFTMAX)
	{
		SG_ERROR("%s::rescale_outputs(): Need to specify sigmoid parameters!\n", get_name());
	}
	else
	{
		SG_ERROR("%s::rescale_outputs(): Unknown OVA probability heuristic type!\n", get_name());
	}
}
87
88
/* Rescale outputs in place when sigmoid parameters are available.
 *
 * For OVA_SOFTMAX, As and Bs are passed on to rescale_heuris_softmax(). For
 * every other heuristic, As and Bs are ignored and the call is forwarded to
 * the one-argument rescale_outputs().
 */
void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs,
	const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
	if (get_prob_heuris_type() != OVA_SOFTMAX)
	{
		rescale_outputs(outputs);
		return;
	}

	rescale_heuris_softmax(outputs, As, Bs);
}
96
97
/* OVA_NORM heuristic: divide outputs in place by their sum.
 *
 * 1E-10 is added to the sum before dividing, so an exactly-zero sum does not
 * cause a division by zero. SG_ERROR is raised if outputs does not have
 * exactly m_num_classes entries.
 *
 * NOTE(review): if outputs can be negative, the sum may still be near zero or
 * negative. This heuristic does not guard against that.
 */
void CMulticlassOneVsRestStrategy::rescale_heuris_norm(SGVector<float64_t> outputs)
{
	if (outputs.vlen != m_num_classes)
	{
		SG_ERROR("%s::rescale_heuris_norm(): size(outputs) = %d != m_num_classes = %d\n",
			get_name(), outputs.vlen, m_num_classes);
	}

	const float64_t denom = SGVector<float64_t>::sum(outputs) + 1E-10;
	for (int32_t k = 0; k < outputs.vlen; ++k)
		outputs[k] /= denom;
}
110
111
/* OVA_SOFTMAX heuristic: turn per-class outputs into normalised scores.
 *
 * Each output f_i is replaced in place by exp(-As[i] * f_i - Bs[i]), using
 * the sigmoid parameters As and Bs (one pair per class). The vector is then
 * divided by its sum plus 1E-10, so the scores sum to just under one.
 *
 * SG_ERROR is raised in two cases:
 *  - outputs does not have exactly m_num_classes entries;
 *  - As or Bs has fewer entries than outputs. Previously this case indexed
 *    As/Bs out of bounds.
 */
void CMulticlassOneVsRestStrategy::rescale_heuris_softmax(SGVector<float64_t> outputs,
	const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
	if (m_num_classes != outputs.vlen)
	{
		SG_ERROR("%s::rescale_heuris_softmax(): size(outputs) = %d != m_num_classes = %d\n",
			get_name(), outputs.vlen, m_num_classes);
	}

	// The loop below reads As[i] and Bs[i] for every i < outputs.vlen.
	if (As.vlen < outputs.vlen || Bs.vlen < outputs.vlen)
	{
		SG_ERROR("%s::rescale_heuris_softmax(): size(As) = %d, size(Bs) = %d, expected at least %d\n",
			get_name(), As.vlen, Bs.vlen, outputs.vlen);
	}

	for (int32_t i = 0; i < outputs.vlen; i++)
		outputs[i] = CMath::exp(-As[i] * outputs[i] - Bs[i]);

	float64_t norm = SGVector<float64_t>::sum(outputs);
	norm += 1E-10;
	for (int32_t i = 0; i < outputs.vlen; i++)
		outputs[i] /= norm;
}
SHOGUN
Machine Learning Toolbox - Documentation