SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
machine
gp
InferenceMethod.h
Go to the documentation of this file.
1
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Roman Votyakov
 * Written (W) 2013 Heiko Strathmann
 * Copyright (C) 2012 Jacob Walker
 * Copyright (C) 2013 Roman Votyakov
 */
12
13
#ifndef CINFERENCEMETHOD_H_
14
#define CINFERENCEMETHOD_H_
15
16
#include <
shogun/lib/config.h
>
17
18
#ifdef HAVE_EIGEN3
19
20
#include <
shogun/base/SGObject.h
>
21
#include <
shogun/kernel/Kernel.h
>
22
#include <
shogun/features/Features.h
>
23
#include <
shogun/labels/Labels.h
>
24
#include <
shogun/machine/gp/LikelihoodModel.h
>
25
#include <
shogun/machine/gp/MeanFunction.h
>
26
#include <
shogun/evaluation/DifferentiableFunction.h
>
27
28
namespace
shogun
29
{
30
32
/** @brief Identifies which concrete inference method an object implements.
 *
 * The enumerators carry explicit values spaced by 10.
 * NOTE(review): the explicit numbering is presumably relied upon elsewhere
 * (e.g. serialization or comparisons); keep the values unchanged.
 *
 * This is deliberately a plain (unscoped) enum: callers refer to the
 * enumerators unqualified (e.g. INF_NONE).
 */
enum EInferenceType
{
	/// no specific inference method; the value the base class reports
	INF_NONE=0,
	/// presumably exact (closed-form) inference — confirm against derived class
	INF_EXACT=10,
	/// presumably the FITC sparse approximation — confirm against derived class
	INF_FITC=20,
	/// presumably the Laplace approximation — confirm against derived class
	INF_LAPLACIAN=30,
	/// presumably expectation propagation — confirm against derived class
	INF_EP=40
};
40
49
class
CInferenceMethod
:
public
CDifferentiableFunction
50
{
51
public
:
53
CInferenceMethod
();
54
63
CInferenceMethod
(
CKernel
* kernel,
CFeatures
* features,
64
CMeanFunction
* mean,
CLabels
* labels,
CLikelihoodModel
* model);
65
66
virtual
~CInferenceMethod
();
67
72
virtual
EInferenceType
get_inference_type
()
const
{
return
INF_NONE
; }
73
85
virtual
float64_t
get_negative_log_marginal_likelihood
()=0;
86
122
float64_t
get_marginal_likelihood_estimate
(int32_t num_importance_samples=1,
123
float64_t
ridge_size=1e-15);
124
138
virtual
CMap<TParameter*, SGVector<float64_t>
>*
get_negative_log_marginal_likelihood_derivatives
(
139
CMap<TParameter*, CSGObject*>
* parameters);
140
151
virtual
SGVector<float64_t>
get_alpha
()=0;
152
164
virtual
SGMatrix<float64_t>
get_cholesky
()=0;
165
177
virtual
SGVector<float64_t>
get_diagonal_vector
()=0;
178
194
virtual
SGVector<float64_t>
get_posterior_mean
()=0;
195
211
virtual
SGMatrix<float64_t>
get_posterior_covariance
()=0;
212
220
virtual
CMap<TParameter*, SGVector<float64_t>
>*
get_gradient
(
221
CMap<TParameter*, CSGObject*>
* parameters)
222
{
223
return
get_negative_log_marginal_likelihood_derivatives
(parameters);
224
}
225
230
virtual
SGVector<float64_t>
get_value
()
231
{
232
SGVector<float64_t>
result(1);
233
result[0]=
get_negative_log_marginal_likelihood
();
234
return
result;
235
}
236
241
virtual
CFeatures
*
get_features
() {
SG_REF
(
m_features
);
return
m_features
; }
242
247
virtual
void
set_features
(
CFeatures
* feat)
248
{
249
SG_REF
(feat);
250
SG_UNREF
(
m_features
);
251
m_features
=feat;
252
}
253
258
virtual
CKernel
*
get_kernel
() {
SG_REF
(
m_kernel
);
return
m_kernel
; }
259
264
virtual
void
set_kernel
(
CKernel
* kern)
265
{
266
SG_REF
(kern);
267
SG_UNREF
(
m_kernel
);
268
m_kernel
=kern;
269
}
270
275
virtual
CMeanFunction
*
get_mean
() {
SG_REF
(
m_mean
);
return
m_mean
; }
276
281
virtual
void
set_mean
(
CMeanFunction
* m)
282
{
283
SG_REF
(m);
284
SG_UNREF
(
m_mean
);
285
m_mean
=m;
286
}
287
292
virtual
CLabels
*
get_labels
() {
SG_REF
(
m_labels
);
return
m_labels
; }
293
298
virtual
void
set_labels
(
CLabels
* lab)
299
{
300
SG_REF
(lab);
301
SG_UNREF
(
m_labels
);
302
m_labels
=lab;
303
}
304
309
CLikelihoodModel
*
get_model
() {
SG_REF
(
m_model
);
return
m_model
; }
310
315
virtual
void
set_model
(
CLikelihoodModel
* mod)
316
{
317
SG_REF
(mod);
318
SG_UNREF
(
m_model
);
319
m_model
=mod;
320
}
321
326
virtual
float64_t
get_scale
()
const
{
return
m_scale
; }
327
332
virtual
void
set_scale
(
float64_t
scale) {
m_scale
=scale; }
333
339
virtual
bool
supports_regression
()
const
{
return
false
; }
340
346
virtual
bool
supports_binary
()
const
{
return
false
; }
347
353
virtual
bool
supports_multiclass
()
const
{
return
false
; }
354
356
virtual
void
update
();
357
358
protected
:
360
virtual
void
check_members
()
const
;
361
363
virtual
void
update_alpha
()=0;
364
366
virtual
void
update_chol
()=0;
367
371
virtual
void
update_deriv
()=0;
372
374
virtual
void
update_train_kernel
();
375
383
virtual
SGVector<float64_t>
get_derivative_wrt_inference_method
(
384
const
TParameter
* param)=0;
385
393
virtual
SGVector<float64_t>
get_derivative_wrt_likelihood_model
(
394
const
TParameter
* param)=0;
395
403
virtual
SGVector<float64_t>
get_derivative_wrt_kernel
(
404
const
TParameter
* param)=0;
405
413
virtual
SGVector<float64_t>
get_derivative_wrt_mean
(
414
const
TParameter
* param)=0;
415
419
static
void
*
get_derivative_helper
(
void
* p);
420
421
private
:
422
void
init();
423
424
protected
:
426
CKernel
*
m_kernel
;
427
429
CMeanFunction
*
m_mean
;
430
432
CLikelihoodModel
*
m_model
;
433
435
CFeatures
*
m_features
;
436
438
CLabels
*
m_labels
;
439
441
SGVector<float64_t>
m_alpha
;
442
444
SGMatrix<float64_t>
m_L
;
445
447
float64_t
m_scale
;
448
450
SGMatrix<float64_t>
m_ktrtr
;
451
};
452
}
453
#endif
/* HAVE_EIGEN3 */
454
#endif
/* CINFERENCEMETHOD_H_ */
SHOGUN
Machine Learning Toolbox - Documentation