15 using namespace shogun;
23 SG_ERROR(
"Expected StreamingVwFeatures\n");
24 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
27 vector<int32_t> predicts;
29 m_feats->start_parser();
30 while (m_feats->get_next_example())
32 predicts.push_back(apply_multiclass_example(m_feats->get_example()));
33 m_feats->release_example();
35 m_feats->end_parser();
38 for (
size_t i=0; i < predicts.size(); ++i)
47 compute_conditional_probabilities(ex);
49 for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
51 probs[it->first] = accumulate_conditional_probability(it->second);
58 stack<node_t *> nodes;
61 while (!nodes.empty())
67 nodes.push(node->
left());
68 nodes.push(node->
right());
82 if (leaf == par->
left())
99 SG_ERROR(
"Expected StreamingVwFeatures\n");
100 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
105 SG_ERROR(
"No data features provided\n");
108 m_machines->reset_array();
114 m_feats->start_parser();
115 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
117 while (m_feats->get_next_example())
119 train_example(m_feats->get_example());
120 m_feats->release_example();
123 if (ipass < m_num_passes-1)
124 m_feats->reset_stream();
126 m_feats->end_parser();
133 int32_t label =
static_cast<int32_t
>(ex->
ld->
label);
138 m_root->data.label = label;
139 printf(
" insert %d %p\n", label, m_root);
140 m_leaves.insert(make_pair(label, m_root));
141 m_root->machine(create_machine(ex));
145 if (m_leaves.find(label) != m_leaves.end())
147 train_path(ex, m_leaves[label]);
152 while (node->
left() != NULL)
155 bool is_left = which_subtree(node, ex);
160 train_node(ex, node);
165 node = node->
right();
178 m_machines->push_back(vw);
179 left_node->
machine(m_machines->get_num_elements()-1);
180 printf(
" insert %d %p\n", left_node->
data.
label, left_node);
181 m_leaves.insert(make_pair(left_node->
data.
label, left_node));
182 node->
left(left_node);
186 right_node->
machine(create_machine(ex));
187 printf(
" insert %d %p\n", label, right_node);
188 m_leaves.insert(make_pair(label, right_node));
189 node->
right(right_node);
196 train_node(ex, node);
201 if (par->
left() == node)
229 m_machines->push_back(vw);
230 return m_machines->get_num_elements()-1;