SHOGUN
v2.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
multiclass
MulticlassOneVsOneStrategy.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#include <
shogun/multiclass/MulticlassOneVsOneStrategy.h
>
12
#include <
shogun/labels/BinaryLabels.h
>
13
#include <
shogun/labels/MulticlassLabels.h
>
14
15
using namespace
shogun;
16
17
/** Default constructor.
 *
 * Starts with zero pairwise machines. The class-pair cursor is placed on
 * the pair (0, 1), the same position train_start() resets it to, so that
 * the cursor members never hold indeterminate values if they are read
 * before train_start() has been called.
 */
CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
	: CMulticlassStrategy(), m_num_machines(0)
{
	// Assigned in the body rather than the initializer list so this works
	// regardless of which class in the hierarchy declares the members.
	m_train_pair_idx_1=0;
	m_train_pair_idx_2=1;
}
21
22
/** Begin a one-vs-one training run.
 *
 * Hands both label objects to CMulticlassStrategy::train_start() first,
 * then rewinds the class-pair cursor to (0, 1) and sets m_num_machines to
 * the number of unordered class pairs, m_num_classes*(m_num_classes-1)/2.
 *
 * @param orig_labels multiclass labels, forwarded to the base implementation
 * @param train_labels binary labels, forwarded to the base implementation
 */
void CMulticlassOneVsOneStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
{
	CMulticlassStrategy::train_start(orig_labels, train_labels);

	// first pair to be trained is class 0 against class 1
	m_train_pair_idx_1=0;
	m_train_pair_idx_2=1;

	// one machine per unordered pair of classes
	m_num_machines=m_num_classes*(m_num_classes-1)/2;
}
30
31
bool
CMulticlassOneVsOneStrategy::train_has_more
()
32
{
33
return
m_train_iter
<
m_num_machines
;
34
}
35
36
/** Prepare the binary problem for the current class pair and advance to
 * the next pair.
 *
 * After calling CMulticlassStrategy::train_prepare_next(), every example
 * whose original label equals m_train_pair_idx_1 gets binary label +1.0
 * and every example whose original label equals m_train_pair_idx_2 gets
 * binary label -1.0; all other examples are left untouched and excluded.
 * The pair cursor is then moved forward: the second index is incremented,
 * and once it reaches m_num_classes the first index is incremented and the
 * second restarts right after it.
 *
 * @return indices (in increasing order) of the examples that belong to
 * either class of the pair that was just prepared
 */
SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
{
	CMulticlassStrategy::train_prepare_next();

	// Same C-style casts as before, just done once up front.
	CMulticlassLabels* multiclass_labels=(CMulticlassLabels*) m_orig_labels;
	CBinaryLabels* binary_labels=(CBinaryLabels*) m_train_labels;

	// Allocate for the worst case (every example selected); shrunk below.
	SGVector<int32_t> subset(m_orig_labels->get_num_labels());
	int32_t num_selected=0;

	for (int32_t idx=0; idx < m_orig_labels->get_num_labels(); ++idx)
	{
		float64_t binary_value;

		if (multiclass_labels->get_int_label(idx)==m_train_pair_idx_1)
			binary_value=+1.0;
		else if (multiclass_labels->get_int_label(idx)==m_train_pair_idx_2)
			binary_value=-1.0;
		else
			continue; // example belongs to neither class of this pair

		binary_labels->set_label(idx, binary_value);
		subset[num_selected]=idx;
		num_selected++;
	}

	// Step the cursor to the next (first, second) pair with first < second.
	++m_train_pair_idx_2;
	if (m_train_pair_idx_2 >= m_num_classes)
	{
		++m_train_pair_idx_1;
		m_train_pair_idx_2=m_train_pair_idx_1+1;
	}

	subset.resize_vector(num_selected);
	return subset;
}
68
69
/** Decide the winning class from the outputs of all pairwise machines.
 *
 * outputs[s] is read as the output of the s-th machine, where machines are
 * enumerated over class pairs (i, j) with i < j, i in the outer loop and j
 * in the inner loop. A strictly positive output is a vote for class i, any
 * other value is a vote for class j. The absolute output is added to the
 * winner's accumulated decision value.
 *
 * The class with the most votes wins. Equal vote counts are resolved in
 * favour of the larger accumulated decision value; if that is equal too,
 * the lowest class index is kept.
 *
 * @param outputs one output per pairwise machine; the loops index entries
 * 0 .. m_num_classes*(m_num_classes-1)/2 - 1
 * @return index of the winning class (0 if m_num_classes is not positive)
 */
int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs)
{
	SGVector<int32_t> votes(m_num_classes);
	// Decision values are sums of float64_t magnitudes. Keeping them in an
	// integer vector truncated every addition toward zero (margins below 1
	// contributed nothing), which made the tie-break meaningless.
	SGVector<float64_t> dec_vals(m_num_classes);
	votes.zero();
	dec_vals.zero();

	int32_t s=0;
	for (int32_t i=0; i<m_num_classes; i++)
	{
		for (int32_t j=i+1; j<m_num_classes; j++)
		{
			const int32_t winner=(outputs[s]>0) ? i : j;
			votes[winner]++;
			dec_vals[winner]+=CMath::abs(outputs[s]);
			s++;
		}
	}

	int32_t i_max=0;
	int32_t vote_max=-1;
	float64_t dec_val_max=-1;

	for (int32_t i=0; i<m_num_classes; ++i)
	{
		if (votes[i] > vote_max)
		{
			i_max=i;
			vote_max=votes[i];
			dec_val_max=dec_vals[i];
		}
		else if (votes[i]==vote_max && dec_vals[i] > dec_val_max)
		{
			// same number of votes: prefer the more confident class
			i_max=i;
			dec_val_max=dec_vals[i];
		}
	}

	return i_max;
}
SHOGUN
Machine Learning Toolbox - Documentation