SHOGUN
v2.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
multiclass
tree
RelaxedTree.h
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#ifndef RELAXEDTREE_H__
12
#define RELAXEDTREE_H__
13
14
#include <utility>
15
#include <vector>
16
17
#include <
shogun/features/DenseFeatures.h
>
18
#include <
shogun/classifier/svm/LibSVM.h
>
19
#include <
shogun/multiclass/tree/TreeMachine.h
>
20
#include <
shogun/multiclass/tree/RelaxedTreeNodeData.h
>
21
22
namespace
shogun
23
{
24
25
// Forward declaration: this header only refers to CBaseMulticlassMachine
// through pointers (setter argument and member), so the full definition is
// not required here.
class
CBaseMulticlassMachine;
26
34
/** @brief Multiclass machine derived from CTreeMachine<RelaxedTreeNodeData>.
 *
 * The object holds (with reference counting via SG_REF / SG_UNREF) the
 * training features, a kernel and an auxiliary multiclass machine, plus a
 * handful of scalar parameters (SVM C, SVM epsilon, A, B and the maximum
 * number of iterations) that are exposed through simple setters/getters.
 * The training and prediction routines are only declared here; their
 * definitions live in the implementation file.
 */
class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
{
public:
    /** default constructor */
    CRelaxedTree();

    /** destructor */
    virtual ~CRelaxedTree();

    /** @return the name of this class, "RelaxedTree" */
    virtual const char* get_name() const { return "RelaxedTree"; }

    /** apply the machine to data and produce multiclass labels
     *
     * @param data features to apply to; defaults to NULL
     * @return the resulting multiclass labels
     */
    virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

    /** set the dense features stored by this machine
     *
     * The new object is referenced before the old one is released, so
     * passing the object that is already held does not drop it.
     *
     * @param feats features to store
     */
    void set_features(CDenseFeatures<float64_t> *feats)
    {
        SG_REF(feats);
        SG_UNREF(m_feats);
        m_feats = feats;
    }

    /** set the kernel stored by this machine
     *
     * Same reference-then-release order as set_features().
     *
     * @param kernel kernel to store
     */
    virtual void set_kernel(CKernel *kernel)
    {
        SG_REF(kernel);
        SG_UNREF(m_kernel);
        m_kernel = kernel;
    }

    /** set the labels and record their number of classes
     *
     * @param lab labels; must be non-NULL and of dynamic type
     *        CMulticlassLabels, otherwise REQUIRE fails
     */
    virtual void set_labels(CLabels* lab)
    {
        CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
        // Validate the cast result rather than the raw argument: a non-NULL
        // lab of the wrong type gives mlab == NULL, which is dereferenced
        // below. A NULL lab also yields a NULL mlab, so it is still rejected.
        REQUIRE(mlab, "requires MulticlassLabels\n");

        CMachine::set_labels(mlab);
        m_num_classes = mlab->get_num_classes();
    }

    /** set the auxiliary multiclass machine
     *
     * NOTE(review): judging by the name, this machine is used to obtain the
     * confusion matrix during training -- confirm against the implementation.
     *
     * @param machine multiclass machine to store
     */
    void set_machine_for_confusion_matrix(CBaseMulticlassMachine *machine)
    {
        SG_REF(machine);
        SG_UNREF(m_machine_for_confusion_matrix);
        m_machine_for_confusion_matrix = machine;
    }

    /** set SVM parameter C
     * @param C new value
     */
    void set_svm_C(float64_t C)
    {
        m_svm_C = C;
    }

    /** @return SVM parameter C */
    float64_t get_svm_C() const
    {
        return m_svm_C;
    }

    /** set SVM epsilon
     * @param epsilon new value
     */
    void set_svm_epsilon(float64_t epsilon)
    {
        m_svm_epsilon = epsilon;
    }

    /** @return SVM epsilon */
    float64_t get_svm_epsilon() const
    {
        return m_svm_epsilon;
    }

    /** set parameter A
     *
     * NOTE(review): the meaning of A is not visible in this header; see the
     * implementation / algorithm description.
     *
     * @param A new value
     */
    void set_A(float64_t A)
    {
        m_A = A;
    }

    /** @return parameter A */
    float64_t get_A() const
    {
        return m_A;
    }

    /** set parameter B
     *
     * NOTE(review): presumably related to the balance constraints enforced
     * by enforce_balance_constraints_upper/lower -- confirm.
     *
     * @param B new value
     */
    void set_B(int32_t B)
    {
        m_B = B;
    }

    /** @return parameter B */
    int32_t get_B() const
    {
        return m_B;
    }

    /** set the maximum number of iterations
     * @param n_iter new value
     */
    void set_max_num_iter(int32_t n_iter)
    {
        m_max_num_iter = n_iter;
    }

    /** @return the maximum number of iterations */
    int32_t get_max_num_iter() const
    {
        return m_max_num_iter;
    }

    /** train the machine
     *
     * Simply forwards to CMachine::train().
     *
     * @param data training data; defaults to NULL
     * @return the result of CMachine::train()
     */
    virtual bool train(CFeatures* data=NULL)
    {
        return CMachine::train(data);
    }

    /** entry type: a pair of int32_t values together with a float64_t value */
    typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;

protected:
    /** apply the machine to a single example
     *
     * @param idx index of the example (presumably into the stored features)
     * @return the result for that example
     */
    float64_t apply_one(int32_t idx);

    /** train the machine; defined in the implementation file
     *
     * @param data training data
     * @return whether training succeeded
     */
    virtual bool train_machine(CFeatures* data);

    /** train a single tree node
     *
     * @param conf_mat confusion matrix
     * @param classes classes handled by the node
     * @return the trained node
     */
    node_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);

    /** compute initialization entries for a node
     *
     * @param global_conf_mat global confusion matrix
     * @param classes classes handled by the node
     * @return vector of entry_t values
     */
    std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);

    /** train a node starting from a given initialization entry
     *
     * @param mu_entry initialization entry
     * @param classes classes handled by the node
     * @param svm SVM to use
     * @return an int32_t vector (presumably the resulting class coloring)
     */
    SGVector<int32_t> train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector<int32_t> classes, CSVM *svm);

    /** compute a score for a vector mu and an SVM
     *
     * @param mu int32_t vector
     * @param svm SVM
     * @return the score
     */
    float64_t compute_score(SGVector<int32_t> mu, CSVM *svm);

    /** color the label space
     *
     * @param svm SVM
     * @param classes classes to color
     * @return an int32_t vector
     */
    SGVector<int32_t> color_label_space(CSVM *svm, SGVector<int32_t> classes);

    /** evaluate the binary model K
     *
     * @param svm SVM
     * @return a float64_t vector
     */
    SGVector<float64_t> eval_binary_model_K(CSVM *svm);

    /** enforce the upper balance constraints; arguments other than B_prime
     * are taken by non-const reference
     */
    void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);

    /** enforce the lower balance constraints; arguments other than B_prime
     * are taken by non-const reference
     */
    void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);

    /** maximum number of iterations */
    int32_t m_max_num_iter;
    /** parameter A */
    float64_t m_A;
    /** parameter B */
    int32_t m_B;
    /** SVM parameter C */
    float64_t m_svm_C;
    /** SVM epsilon */
    float64_t m_svm_epsilon;
    /** kernel (set via set_kernel) */
    CKernel *m_kernel;
    /** features (set via set_features) */
    CDenseFeatures<float64_t> *m_feats;
    /** auxiliary multiclass machine (set via set_machine_for_confusion_matrix) */
    CBaseMulticlassMachine *m_machine_for_confusion_matrix;
    /** number of classes, recorded by set_labels */
    int32_t m_num_classes;
};
241
242
}
/* shogun */
243
244
#endif
/* end of include guard: RELAXEDTREE_H__ */
245
SHOGUN
Machine Learning Toolbox - Documentation