24 SG_ERROR(
"Expected StreamingVwFeatures\n")
25 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
28 vector<int32_t> predicts;
30 m_feats->start_parser();
31 while (m_feats->get_next_example())
33 predicts.push_back(apply_multiclass_example(m_feats->get_example()));
34 m_feats->release_example();
36 m_feats->end_parser();
39 for (
size_t i=0; i < predicts.size(); ++i)
48 compute_conditional_probabilities(ex);
50 for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
52 probs[it->first] = accumulate_conditional_probability(it->second);
59 stack<bnode_t *> nodes;
62 while (!nodes.empty())
68 nodes.push(node->
left());
69 nodes.push(node->
right());
72 node->
data.p_right = train_node(ex, node);
83 if (leaf == par->
left())
84 prob *= (1-par->
data.p_right);
86 prob *= par->
data.p_right;
100 SG_ERROR(
"Expected StreamingVwFeatures\n")
101 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
106 SG_ERROR(
"No data features provided\n")
109 m_machines->reset_array();
115 m_feats->start_parser();
116 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
118 while (m_feats->get_next_example())
120 train_example(m_feats->get_example());
121 m_feats->release_example();
124 if (ipass < m_num_passes-1)
125 m_feats->reset_stream();
127 m_feats->end_parser();
134 int32_t label =
static_cast<int32_t
>(ex->
ld->
label);
139 m_root->data.label = label;
140 printf(
" insert %d %p\n", label, m_root);
141 m_leaves.insert(make_pair(label,(
bnode_t*) m_root));
142 m_root->machine(create_machine(ex));
146 if (m_leaves.find(label) != m_leaves.end())
148 train_path(ex, m_leaves[label]);
153 while (node->
left() != NULL)
156 bool is_left = which_subtree(node, ex);
161 train_node(ex, node);
166 node = node->
right();
169 printf(
" remove %d %p\n", node->
data.label, m_leaves[node->
data.label]);
170 m_leaves.erase(node->
data.label);
173 left_node->
data.label = node->
data.label;
174 node->
data.label = -1;
179 m_machines->push_back(vw);
180 left_node->
machine(m_machines->get_num_elements()-1);
181 printf(
" insert %d %p\n", left_node->
data.label, left_node);
182 m_leaves.insert(make_pair(left_node->
data.label, left_node));
183 node->
left(left_node);
186 right_node->
data.label = label;
187 right_node->
machine(create_machine(ex));
188 printf(
" insert %d %p\n", label, right_node);
189 m_leaves.insert(make_pair(label, right_node));
190 node->
right(right_node);
197 train_node(ex, node);
202 if (par->
left() == node)
230 m_machines->push_back(vw);
231 return m_machines->get_num_elements()-1;
void parent(CTreeMachineNode *par)
The node of the tree structure forming a TreeMachine. The node contains a pointer to its parent and pointers to its children.
void machine(int32_t idx)
float64_t train_node(VwExample *ex, bnode_t *node)
static int32_t arg_max(T *vec, int32_t inc, int32_t len, T *maxv_ptr=NULL)
void train_example(VwExample *ex)
CVwLearner * get_learner()
float64_t accumulate_conditional_probability(bnode_t *leaf)
float32_t label
Label value.
Multiclass Labels for multi-class classification.
virtual void set_learner()
void right(CBinaryTreeMachineNode *r)
virtual bool train_machine(CFeatures *data)
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
virtual void train(VwExample *&ex, float32_t update)=0
virtual EFeatureClass get_feature_class() const =0
int32_t create_machine(VwExample *ex)
virtual float32_t predict_and_finalize(VwExample *ex)
void train_path(VwExample *ex, bnode_t *node)
CBinaryTreeMachineNode< VwConditionalProbabilityTreeNodeData > bnode_t
All classes and functions are contained in the shogun namespace.
bool set_int_label(int32_t idx, int32_t label)
The class Features is the base class of all feature objects.
VwLabel * ld
Label object.
float32_t eta_round
Learning rate for this round.
Class CVowpalWabbit is the implementation of the online learning algorithm used in Vowpal Wabbit.
virtual int32_t apply_multiclass_example(VwExample *ex)
void left(CBinaryTreeMachineNode *l)
void compute_conditional_probabilities(VwExample *ex)