SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
machine
StructuredOutputMachine.cpp
Go to the documentation of this file.
1
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Thoralf Klein
 * Written (W) 2012 Fernando José Iglesias García
 * Copyright (C) 2012 Fernando José Iglesias García
 */
11
12
#include <
shogun/machine/StructuredOutputMachine.h
>
13
14
using namespace
shogun;
15
16
/* Default constructor: the machine starts without a structured model and
 * without a surrogate loss; both members are then registered as parameters. */
CStructuredOutputMachine::CStructuredOutputMachine()
	: CMachine(),
	  m_model(NULL),
	  m_surrogate_loss(NULL)
{
	register_parameters();
}
21
22
/* Construct a machine around a structured model and its training labels.
 *
 * The model is stored with its reference count incremented (via set_model()).
 * It has to be in place before set_labels() runs, because set_labels()
 * forwards the labels to the model.
 *
 * @param model structured model to train
 * @param labs  structured training labels
 */
CStructuredOutputMachine::CStructuredOutputMachine(
		CStructuredModel* model,
		CStructuredLabels* labs)
	: CMachine(), m_model(NULL), m_surrogate_loss(NULL)
{
	// m_model is still NULL here, so set_model() only takes a reference on model.
	set_model(model);
	set_labels(labs);
	register_parameters();
}
31
32
/* Destructor: drops the references this machine holds on its structured
 * model and on its surrogate loss (either may be NULL). */
CStructuredOutputMachine::~CStructuredOutputMachine()
{
	SG_UNREF(m_model);
	SG_UNREF(m_surrogate_loss);
}
37
38
/* Replace the structured model used by this machine.
 *
 * The incoming model is referenced before the current one is released, so
 * passing the already-set model again does not destroy it prematurely.
 *
 * @param model new structured model (may be NULL)
 */
void CStructuredOutputMachine::set_model(CStructuredModel* model)
{
	SG_REF(model);
	SG_UNREF(m_model);
	m_model = model;
}
44
45
/* Access the structured model.
 *
 * @return the current model with its reference count incremented; the caller
 *         owns that reference and must SG_UNREF it. NULL if no model is set.
 */
CStructuredModel* CStructuredOutputMachine::get_model() const
{
	CStructuredModel* const model = m_model;
	SG_REF(model);
	return model;
}
50
51
void
CStructuredOutputMachine::register_parameters()
52
{
53
SG_ADD
((
CSGObject
**)&
m_model
,
"m_model"
,
"Structured model"
,
MS_NOT_AVAILABLE
);
54
SG_ADD
((
CSGObject
**)&
m_surrogate_loss
,
"m_surrogate_loss"
,
"Surrogate loss"
,
MS_NOT_AVAILABLE
);
55
}
56
57
/* Set the training labels on the machine and on its structured model.
 *
 * A structured model must already be set. The check is performed before any
 * state is modified, so a failing call no longer leaves the machine holding
 * new labels while the model still has the old ones.
 *
 * @param lab training labels; handed to the model after conversion with
 *            CLabelsFactory::to_structured()
 */
void CStructuredOutputMachine::set_labels(CLabels* lab)
{
	REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n");

	CMachine::set_labels(lab);
	m_model->set_labels(CLabelsFactory::to_structured(lab));
}
63
64
/* Set the features on the underlying structured model.
 *
 * A structured model must already be set; previously a missing model caused
 * a NULL pointer dereference here.
 *
 * @param f features forwarded to the model
 */
void CStructuredOutputMachine::set_features(CFeatures* f)
{
	REQUIRE(m_model != NULL, "please call set_model() before set_features()\n");

	m_model->set_features(f);
}
68
69
/* Return the features held by the underlying structured model.
 *
 * A structured model must be set; previously a missing model caused a NULL
 * pointer dereference here. The pointer comes straight from the model's
 * get_features(). Callers in this file release it with SG_UNREF.
 *
 * @return the model's features
 */
CFeatures* CStructuredOutputMachine::get_features() const
{
	REQUIRE(m_model != NULL, "please call set_model() before get_features()\n");

	return m_model->get_features();
}
73
74
/* Replace the surrogate loss function.
 *
 * The new loss is referenced before the old one is released, which keeps
 * re-assigning the same object safe.
 *
 * @param loss new surrogate loss (may be NULL)
 */
void CStructuredOutputMachine::set_surrogate_loss(CLossFunction* loss)
{
	SG_REF(loss);
	SG_UNREF(m_surrogate_loss);
	m_surrogate_loss = loss;
}
80
81
/* Access the surrogate loss function.
 *
 * @return the current loss with its reference count incremented; the caller
 *         owns that reference and must SG_UNREF it. NULL if none is set.
 */
CLossFunction* CStructuredOutputMachine::get_surrogate_loss() const
{
	CLossFunction* const loss = m_surrogate_loss;
	SG_REF(loss);
	return loss;
}
86
87
/* n-slack, margin-rescaling risk and its subgradient at W.
 *
 * For every example index i in [from, to), the model's argmax is evaluated
 * with training=true. The risk is the sum of the returned scores, and the
 * subgradient is the sum of (psi_pred - psi_truth) over those examples.
 *
 * @param subgrad output buffer of m_model->get_dim() entries; zeroed and then
 *                accumulated into
 * @param W       weight vector of m_model->get_dim() entries; wrapped without
 *                copying
 * @param info    optional example range: [m_from, m_from + m_N), where
 *                m_N == 0 means "up to the last example"; NULL selects all
 *                examples
 * @return accumulated risk
 *
 * Raises an error when no model or no features are set, or when the
 * requested range does not lie inside [0, number of examples].
 */
float64_t CStructuredOutputMachine::risk_nslack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
{
	REQUIRE(m_model != NULL, "%s::risk_nslack_margin_rescale(): no structured model set!\n", get_name());
	const int32_t dim = m_model->get_dim();

	CFeatures* features = get_features();
	REQUIRE(features != NULL, "%s::risk_nslack_margin_rescale(): the structured model has no features!\n", get_name());
	const int32_t num_vectors = features->get_num_vectors();
	SG_UNREF(features);

	// Half-open range [from, to) of examples to process.
	int32_t from = 0;
	int32_t to = num_vectors;
	if (info)
	{
		from = info->m_from;
		if (info->m_N != 0)
			to = from + info->m_N;
	}
	// Guard against indexing examples that do not exist.
	REQUIRE(from >= 0 && from <= to && to <= num_vectors,
			"%s::risk_nslack_margin_rescale(): example range [%d, %d) is outside [0, %d)\n",
			get_name(), from, to, num_vectors);

	float64_t R = 0.0;
	for (int32_t i = 0; i < dim; i++)
		subgrad[i] = 0;

	// Non-owning (no ref counting) view on W, built once outside the loop.
	SGVector<float64_t> w(W, dim, false);
	for (int32_t i = from; i < to; i++)
	{
		CResultSet* result = m_model->argmax(w, i, true);
		SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, result->psi_pred.vector, dim);
		SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, result->psi_truth.vector, dim);
		R += result->score;
		SG_UNREF(result);
	}

	return R;
}
122
123
/* n-slack, slack-rescaling risk: not provided by this base class.
 * Always reports an error through SG_ERROR. The 0.0 return value only
 * satisfies the signature. */
float64_t CStructuredOutputMachine::risk_nslack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
{
	const char* const self = get_name();
	SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", self);
	return 0.0;
}
128
129
/* 1-slack, margin-rescaling risk: not provided by this base class.
 * Always reports an error through SG_ERROR. The 0.0 return value only
 * satisfies the signature. */
float64_t CStructuredOutputMachine::risk_1slack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
{
	const char* const self = get_name();
	SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", self);
	return 0.0;
}
134
135
/* 1-slack, slack-rescaling risk: not provided by this base class.
 * Always reports an error through SG_ERROR. The 0.0 return value only
 * satisfies the signature. */
float64_t CStructuredOutputMachine::risk_1slack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
{
	const char* const self = get_name();
	SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", self);
	return 0.0;
}
140
141
/* User-defined risk formulation: not provided by this base class.
 * Always reports an error through SG_ERROR. The 0.0 return value only
 * satisfies the signature. */
float64_t CStructuredOutputMachine::risk_customized_formulation(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
{
	const char* const self = get_name();
	SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", self);
	return 0.0;
}
146
147
/* Dispatch to the risk/subgradient routine selected by rtype.
 *
 * @param subgrad output subgradient buffer, filled by the selected routine
 * @param W       current weight vector
 * @param info    optional example range, forwarded unchanged
 * @param rtype   which risk formulation to evaluate
 * @return the selected routine's risk. For an unknown rtype an error is
 *         raised via SG_ERROR, and -1 is the fallback value.
 */
float64_t CStructuredOutputMachine::risk(float64_t* subgrad, float64_t* W,
		TMultipleCPinfo* info, EStructRiskType rtype)
{
	switch (rtype)
	{
		case N_SLACK_MARGIN_RESCALING:
			return risk_nslack_margin_rescale(subgrad, W, info);
		case N_SLACK_SLACK_RESCALING:
			return risk_nslack_slack_rescale(subgrad, W, info);
		case ONE_SLACK_MARGIN_RESCALING:
			return risk_1slack_margin_rescale(subgrad, W, info);
		case ONE_SLACK_SLACK_RESCALING:
			return risk_1slack_slack_rescale(subgrad, W, info);
		case CUSTOMIZED_RISK:
			return risk_customized_formulation(subgrad, W, info);
		default:
			SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name());
			return -1;
	}
}
SHOGUN
Machine Learning Toolbox - Documentation