SHOGUN
v2.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
machine
Machine.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 1999-2009 Soeren Sonnenburg
8
* Written (W) 2011-2012 Heiko Strathmann
9
* Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
10
*/
11
12
#include <
shogun/machine/Machine.h
>
13
#include <
shogun/base/Parameter.h
>
14
#include <
shogun/mathematics/Math.h
>
15
#include <
shogun/base/ParameterMap.h
>
16
17
using namespace
shogun;
18
19
/** Default constructor.
 *
 * Gives every member a defined starting value, registers the members with
 * the parameter framework via SG_ADD and records version info for the
 * "data_locked" parameter in m_parameter_map.
 *
 * NOTE(review): the SG_ADD calls are deliberately kept in their original
 * order; registration order presumably determines the serialisation layout,
 * so do not reorder them without checking.
 */
CMachine::CMachine()
	: CSGObject(),
	  m_max_train_time(0),
	  m_labels(NULL),
	  m_solver_type(ST_AUTO),
	  m_store_model_features(false),
	  m_data_locked(false)
{
	SG_ADD(&m_max_train_time, "max_train_time",
			"Maximum training time.", MS_NOT_AVAILABLE);
	SG_ADD((machine_int_t*) &m_solver_type, "solver_type",
			"Type of solver.", MS_NOT_AVAILABLE);
	SG_ADD((CSGObject**) &m_labels, "labels",
			"Labels to be used.", MS_NOT_AVAILABLE);
	SG_ADD(&m_store_model_features, "store_model_features",
			"Should feature data of model be stored after training?",
			MS_NOT_AVAILABLE);
	SG_ADD(&m_data_locked, "data_locked",
			"Indicates whether data is locked", MS_NOT_AVAILABLE);

	/* Describe "data_locked" as a boolean scalar at parameter version 1 and
	 * pair it with an empty SGParamInfo (presumably meaning "no earlier
	 * version of this parameter" - confirm against ParameterMap). */
	SGParamInfo* const data_locked_info=
			new SGParamInfo("data_locked", CT_SCALAR, ST_NONE, PT_BOOL, 1);
	SGParamInfo* const empty_info=new SGParamInfo();
	m_parameter_map->put(data_locked_info, empty_info);

	m_parameter_map->finalize_map();
}
43
44
/** Destructor: drops this machine's reference to its labels, if any. */
CMachine::~CMachine()
{
	SG_UNREF(this->m_labels);
}
48
49
/** Train the machine on the given features.
 *
 * Refuses to run while data is locked (train_locked() must be used then).
 * If train_require_labels() is true, labels must already be set and are
 * validated with ensure_valid(). The actual work is delegated to
 * train_machine(); afterwards store_model_features() is called when
 * m_store_model_features is set.
 *
 * @param data features to train on (passed straight to train_machine)
 * @return the value returned by train_machine()
 */
bool CMachine::train(CFeatures* data)
{
	// Locked data may only be trained through train_locked().
	if (m_data_locked)
	{
		SG_ERROR("%s::train data_lock() was called, only train_locked() is"
				" possible. Call data_unlock if you want to call train()\n",
				get_name());
	}

	const bool needs_labels=train_require_labels();
	if (needs_labels)
	{
		if (!m_labels)
			SG_ERROR("%s@%p: No labels given", get_name(), this);

		m_labels->ensure_valid(get_name());
	}

	const bool trained=train_machine(data);

	if (m_store_model_features)
		store_model_features();

	return trained;
}
74
75
/** Set the labels used by this machine.
 *
 * Non-NULL labels must pass is_label_valid(), otherwise SG_ERROR is raised.
 * Passing NULL clears the current labels.
 *
 * @param lab new labels (may be NULL); this machine takes its own reference
 */
void CMachine::set_labels(CLabels* lab)
{
	if (lab != NULL && !is_label_valid(lab))
		SG_ERROR("Invalid label for %s", get_name());

	/* Take the new reference BEFORE dropping the old one: if lab is the
	 * object we already hold (lab == m_labels) and we own its only
	 * reference, unref-ing first would free it and the following SG_REF
	 * would touch freed memory. */
	SG_REF(lab);
	SG_UNREF(m_labels);
	m_labels=lab;
}
85
86
/** Return the labels currently held by this machine (may be NULL).
 *
 * The returned pointer has been SG_REF'd here, so the caller holds an
 * extra reference and is expected to release it with SG_UNREF.
 */
CLabels* CMachine::get_labels()
{
	CLabels* const held=m_labels;
	SG_REF(held);
	return held;
}
91
92
/** Store the maximum training time value.
 *
 * @param t new value for m_max_train_time
 */
void CMachine::set_max_train_time(float64_t t)
{
	this->m_max_train_time=t;
}
96
97
/** @return the stored maximum training time (m_max_train_time) */
float64_t CMachine::get_max_train_time()
{
	const float64_t limit=m_max_train_time;
	return limit;
}
101
102
/** Base-class classifier type.
 *
 * @return always CT_NONE here; concrete machines presumably override this.
 */
EMachineType CMachine::get_classifier_type()
{
	const EMachineType base_type=CT_NONE;
	return base_type;
}
106
107
/** Select the solver type.
 *
 * @param st new value for m_solver_type
 */
void CMachine::set_solver_type(ESolverType st)
{
	this->m_solver_type=st;
}
111
112
/** @return the currently selected solver type (m_solver_type) */
ESolverType CMachine::get_solver_type()
{
	const ESolverType current=m_solver_type;
	return current;
}
116
117
void
CMachine::set_store_model_features
(
bool
store_model)
118
{
119
m_store_model_features
= store_model;
120
}
121
122
/** Lock the machine to the given labels and features.
 *
 * Preconditions, each raising SG_ERROR when violated:
 *  - supports_locking() must be true,
 *  - labs must not be NULL,
 *  - the machine must not already be locked.
 * Only after all checks pass are the labels installed via set_labels(),
 * the lock flag set and post_lock() called.
 *
 * @param labs labels to lock to (must not be NULL)
 * @param features features forwarded unchanged to post_lock()
 */
void CMachine::data_lock(CLabels* labs, CFeatures* features)
{
	SG_DEBUG("entering %s::data_lock\n", get_name());

	if (!supports_locking())
	{
		SG_ERROR("%s::data_lock(): Machine does not support data locking!\n",
				get_name());
	}

	if (!labs)
	{
		SG_ERROR("%s::data_lock() is not possible with NULL labels!\n",
				get_name());
	}

	/* Reject a second lock before touching any state; previously the labels
	 * were replaced first and only then was the error raised, leaving the
	 * machine modified by a call that was rejected. */
	if (m_data_locked)
	{
		SG_ERROR("%s::data_lock() was already called. Don't lock twice!\n",
				get_name());
	}

	set_labels(labs);

	m_data_locked=true;
	post_lock(labs, features);
	SG_DEBUG("leaving %s::data_lock\n", get_name());
}
152
153
void
CMachine::data_unlock
()
154
{
155
SG_DEBUG
(
"entering %s::data_lock\n"
,
get_name
());
156
if
(
m_data_locked
)
157
m_data_locked
=
false
;
158
159
SG_DEBUG
(
"leaving %s::data_lock\n"
,
get_name
());
160
}
161
162
CLabels
*
CMachine::apply
(
CFeatures
* data)
163
{
164
SG_DEBUG
(
"entering %s::apply(%s at %p)\n"
,
165
get_name
(), data ? data->
get_name
() :
"NULL"
, data);
166
167
CLabels
* result=NULL;
168
169
switch
(
get_machine_problem_type
())
170
{
171
case
PT_BINARY
:
172
result=
apply_binary
(data);
173
break
;
174
case
PT_REGRESSION
:
175
result=
apply_regression
(data);
176
break
;
177
case
PT_MULTICLASS
:
178
result=
apply_multiclass
(data);
179
break
;
180
case
PT_STRUCTURED
:
181
result=
apply_structured
(data);
182
break
;
183
case
PT_LATENT
:
184
result=
apply_latent
(data);
185
break
;
186
default
:
187
SG_ERROR
(
"Unknown problem type"
);
188
break
;
189
}
190
191
SG_DEBUG
(
"leaving %s::apply(%s at %p)\n"
,
192
get_name
(), data ? data->
get_name
() :
"NULL"
, data);
193
194
return
result;
195
}
196
197
/** Apply the machine to a subset of the locked data.
 *
 * Dispatches on get_machine_problem_type() to the matching apply_locked_*
 * method. Unknown problem types raise "Unknown problem type" via SG_ERROR.
 *
 * @param indices passed unchanged to the apply_locked_* method (presumably
 *        indices into the locked feature set - confirm with callers)
 * @return labels from the matching method, or NULL if none ran
 */
CLabels* CMachine::apply_locked(SGVector<index_t> indices)
{
	CLabels* output=NULL;

	switch (get_machine_problem_type())
	{
		case PT_BINARY:
			output=apply_locked_binary(indices);
			break;
		case PT_REGRESSION:
			output=apply_locked_regression(indices);
			break;
		case PT_MULTICLASS:
			output=apply_locked_multiclass(indices);
			break;
		case PT_STRUCTURED:
			output=apply_locked_structured(indices);
			break;
		case PT_LATENT:
			output=apply_locked_latent(indices);
			break;
		default:
			SG_ERROR("Unknown problem type");
			break;
	}

	return output;
}
217
218
/** Base-class fallback for binary prediction: always raises SG_ERROR.
 *
 * Machines that support binary output presumably override this. The NULL
 * return is reached only if SG_ERROR hands control back.
 */
CBinaryLabels* CMachine::apply_binary(CFeatures* data)
{
	CBinaryLabels* unsupported=NULL;
	SG_ERROR("This machine does not support apply_binary()\n");
	return unsupported;
}
223
224
/** Base-class fallback for regression: always raises SG_ERROR.
 *
 * Machines that support regression output presumably override this. The
 * NULL return is reached only if SG_ERROR hands control back.
 */
CRegressionLabels* CMachine::apply_regression(CFeatures* data)
{
	CRegressionLabels* unsupported=NULL;
	SG_ERROR("This machine does not support apply_regression()\n");
	return unsupported;
}
229
230
/** Base-class fallback for multiclass prediction: always raises SG_ERROR.
 *
 * Machines that support multiclass output presumably override this. The
 * NULL return is reached only if SG_ERROR hands control back.
 */
CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data)
{
	CMulticlassLabels* unsupported=NULL;
	SG_ERROR("This machine does not support apply_multiclass()\n");
	return unsupported;
}
235
236
/** Base-class fallback for structured prediction: always raises SG_ERROR.
 *
 * Machines that support structured output presumably override this. The
 * NULL return is reached only if SG_ERROR hands control back.
 */
CStructuredLabels* CMachine::apply_structured(CFeatures* data)
{
	CStructuredLabels* unsupported=NULL;
	SG_ERROR("This machine does not support apply_structured()\n");
	return unsupported;
}
241
242
/** Base-class fallback for latent prediction: always raises SG_ERROR.
 *
 * Machines that support latent output presumably override this. The NULL
 * return is reached only if SG_ERROR hands control back.
 */
CLatentLabels* CMachine::apply_latent(CFeatures* data)
{
	CLatentLabels* unsupported=NULL;
	SG_ERROR("This machine does not support apply_latent()\n");
	return unsupported;
}
247
248
/** Base-class stub for binary prediction on locked data.
 *
 * Always raises SG_ERROR naming the concrete machine; the NULL return is
 * reached only if SG_ERROR hands control back.
 */
CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> indices)
{
	CBinaryLabels* unsupported=NULL;
	SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return unsupported;
}
254
255
/** Base-class stub for regression on locked data.
 *
 * Always raises SG_ERROR naming the concrete machine; the NULL return is
 * reached only if SG_ERROR hands control back.
 */
CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> indices)
{
	CRegressionLabels* unsupported=NULL;
	SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return unsupported;
}
261
262
/** Base-class stub for multiclass prediction on locked data.
 *
 * Always raises SG_ERROR naming the concrete machine; the NULL return is
 * reached only if SG_ERROR hands control back.
 */
CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices)
{
	CMulticlassLabels* unsupported=NULL;
	SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return unsupported;
}
268
269
/** Base-class stub for structured prediction on locked data.
 *
 * Always raises SG_ERROR naming the concrete machine; the NULL return is
 * reached only if SG_ERROR hands control back.
 */
CStructuredLabels* CMachine::apply_locked_structured(SGVector<index_t> indices)
{
	CStructuredLabels* unsupported=NULL;
	SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return unsupported;
}
275
276
/** Base-class stub for latent prediction on locked data.
 *
 * Always raises SG_ERROR naming the concrete machine; the NULL return is
 * reached only if SG_ERROR hands control back.
 */
CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices)
{
	CLatentLabels* unsupported=NULL;
	SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return unsupported;
}
282
283
SHOGUN
Machine Learning Toolbox - Documentation