SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
latent
LatentModel.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Viktor Gal
8
* Copyright (C) 2012 Viktor Gal
9
*/
10
11
#include <
shogun/latent/LatentModel.h
>
12
#include <
shogun/labels/BinaryLabels.h
>
13
14
using namespace
shogun;
15
16
/** Default constructor.
 *
 * Creates an empty model: no features, no labels, no cached PSI features and
 * PSI caching switched off. The members are then registered as parameters.
 */
CLatentModel::CLatentModel()
{
	m_features = NULL;
	m_labels = NULL;
	m_cached_psi = NULL;
	m_do_caching = false;

	register_parameters();
}
24
25
/** Construct a model over the given latent features and labels.
 *
 * @param feats latent features; a reference is taken (SG_REF)
 * @param labels latent labels; a reference is taken (SG_REF)
 * @param do_caching whether cache_psi_features() should store the PSI
 *        feature vectors
 */
CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
	: m_features(feats), m_labels(labels), m_do_caching(do_caching), m_cached_psi(NULL)
{
	// keep the inputs alive for the lifetime of the model; both references
	// are released again in the destructor
	SG_REF(m_features);
	SG_REF(m_labels);

	register_parameters();
}
35
36
/** Destructor: drops the references held on the labels, the features and
 * the cached PSI features (in that order).
 */
CLatentModel::~CLatentModel()
{
	SG_UNREF(m_labels);
	SG_UNREF(m_features);

	// may be NULL if caching was never performed
	SG_UNREF(m_cached_psi);
}
42
43
/** @return the number of vectors reported by the latent features.
 *
 * NOTE(review): m_features is dereferenced unchecked, so features must have
 * been set (e.g. via the 3-argument constructor or set_features()).
 */
int32_t CLatentModel::get_num_vectors() const
{
	const int32_t vector_count = m_features->get_num_vectors();
	return vector_count;
}
47
48
/** Replace the latent labels held by the model.
 *
 * The new labels are referenced before the old ones are released, so passing
 * the currently held object is safe.
 *
 * @param labs new latent labels (may be NULL)
 */
void CLatentModel::set_labels(CLatentLabels* labs)
{
	CLatentLabels* previous = m_labels;

	SG_REF(labs);
	m_labels = labs;
	SG_UNREF(previous);
}
54
55
/** @return the latent labels, with an extra reference taken (SG_REF); the
 * caller is responsible for SG_UNREF-ing the result.
 */
CLatentLabels* CLatentModel::get_labels() const
{
	CLatentLabels* result = m_labels;
	SG_REF(result);
	return result;
}
60
61
/** Replace the latent features held by the model.
 *
 * The new features are referenced before the old ones are released, so
 * passing the currently held object is safe.
 *
 * @param feats new latent features (may be NULL)
 */
void CLatentModel::set_features(CLatentFeatures* feats)
{
	CLatentFeatures* previous = m_features;

	SG_REF(feats);
	m_features = feats;
	SG_UNREF(previous);
}
67
68
void
CLatentModel::argmax_h
(
const
SGVector<float64_t>
& w)
69
{
70
int32_t num =
get_num_vectors
();
71
CBinaryLabels
* y =
CLabelsFactory::to_binary
(
m_labels
->
get_labels
());
72
ASSERT
(num > 0)
73
ASSERT
(num ==
m_labels
->
get_num_labels
())
74
75
// argmax_h only for positive examples
76
for
(int32_t i = 0; i < num; ++i)
77
{
78
if
(y->
get_label
(i) == 1)
79
{
80
// infer h and set it for the argmax_h <w,psi(x,h)>
81
CData
* latent_data =
infer_latent_variable
(w, i);
82
m_labels
->
set_latent_label
(i, latent_data);
83
}
84
}
85
}
86
87
void
CLatentModel::register_parameters()
88
{
89
m_parameters
->
add
((
CSGObject
**) &
m_features
,
"features"
,
"Latent features"
);
90
m_parameters
->
add
((
CSGObject
**) &
m_labels
,
"labels"
,
"Latent labels"
);
91
m_parameters
->
add
((
CSGObject
**) &
m_cached_psi
,
"cached_psi"
,
"Cached PSI features after argmax_h"
);
92
m_parameters
->
add
(&
m_do_caching
,
"do_caching"
,
"Indicate whether or not do PSI vector caching after argmax_h"
);
93
}
94
95
96
/** @return the latent features, with an extra reference taken (SG_REF); the
 * caller is responsible for SG_UNREF-ing the result.
 */
CLatentFeatures* CLatentModel::get_features() const
{
	CLatentFeatures* result = m_features;
	SG_REF(result);
	return result;
}
101
102
void
CLatentModel::cache_psi_features
()
103
{
104
if
(
m_do_caching
)
105
{
106
if
(
m_cached_psi
)
107
SG_UNREF
(
m_cached_psi
);
108
m_cached_psi
= this->
get_psi_feature_vectors
();
109
SG_REF
(
m_cached_psi
);
110
}
111
}
112
113
/** @return the cached PSI features with an extra reference taken (the caller
 * must SG_UNREF it), or NULL when caching is disabled. Note that the cache
 * itself may still be NULL if cache_psi_features() has not run yet.
 */
CDotFeatures* CLatentModel::get_cached_psi_features() const
{
	if (!m_do_caching)
		return NULL;

	SG_REF(m_cached_psi);
	return m_cached_psi;
}
SHOGUN
Machine Learning Toolbox - Documentation