SHOGUN
v2.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
multiclass
tree
ConditionalProbabilityTree.h
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#ifndef CONDITIONALPROBABILITYTREE_H__
12
#define CONDITIONALPROBABILITYTREE_H__
13
14
#include <map>
15
16
#include <
shogun/features/streaming/StreamingDenseFeatures.h
>
17
#include <
shogun/multiclass/tree/TreeMachine.h
>
18
#include <
shogun/multiclass/tree/ConditionalProbabilityTreeNodeData.h
>
19
20
namespace
shogun
21
{
22
31
class
CConditionalProbabilityTree
:
public
CTreeMachine
<ConditionalProbabilityTreeNodeData>
32
{
33
public
:
35
CConditionalProbabilityTree
(int32_t num_passes=1)
36
:
m_num_passes
(num_passes),
m_feats
(NULL)
37
{
38
}
39
41
virtual
~CConditionalProbabilityTree
() {
SG_UNREF
(
m_feats
); }
42
44
virtual
const
char
*
get_name
()
const
{
return
"ConditionalProbabilityTree"
; }
45
47
void
set_num_passes
(int32_t num_passes)
48
{
49
m_num_passes
= num_passes;
50
}
51
53
int32_t
get_num_passes
()
const
54
{
55
return
m_num_passes
;
56
}
57
61
void
set_features
(
CStreamingDenseFeatures<float32_t>
*feats)
62
{
63
SG_REF
(feats);
64
SG_UNREF
(
m_feats
);
65
m_feats
= feats;
66
}
67
69
virtual
CMulticlassLabels
*
apply_multiclass
(
CFeatures
* data=NULL);
70
74
virtual
int32_t
apply_multiclass_example
(
SGVector<float32_t>
ex);
75
77
void
print_tree
();
78
protected
:
80
virtual
bool
train_require_labels
()
const
{
return
false
; }
81
88
virtual
bool
train_machine
(
CFeatures
* data);
89
94
void
train_example
(
SGVector<float32_t>
ex, int32_t label);
95
100
void
train_path
(
SGVector<float32_t>
ex,
node_t
*
node
);
101
107
void
train_node
(
SGVector<float32_t>
ex,
float64_t
label,
node_t
*
node
);
108
113
float64_t
predict_node
(
SGVector<float32_t>
ex,
node_t
*
node
);
114
118
int32_t
create_machine
(
SGVector<float32_t>
ex);
119
125
virtual
bool
which_subtree
(
node_t
*
node
,
SGVector<float32_t>
ex)=0;
126
128
void
compute_conditional_probabilities
(
SGVector<float32_t>
ex);
129
133
float64_t
accumulate_conditional_probability
(
node_t
*leaf);
134
135
int32_t
m_num_passes
;
136
std::map<int32_t, node_t*>
m_leaves
;
137
CStreamingDenseFeatures<float32_t>
*
m_feats
;
138
};
139
140
}
/* shogun */
141
142
#endif
/* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
143
SHOGUN
Machine Learning Toolbox - Documentation