17 using namespace shogun;
25 SG_ERROR(
"Expected StreamingDenseFeatures\n");
27 SG_ERROR(
"Expected float32_t feature type\n");
32 vector<int32_t> predicts;
34 m_feats->start_parser();
35 while (m_feats->get_next_example())
37 predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
38 m_feats->release_example();
40 m_feats->end_parser();
43 for (
size_t i=0; i < predicts.size(); ++i)
50 compute_conditional_probabilities(ex);
52 for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
54 probs[it->first] = accumulate_conditional_probability(it->second);
61 stack<node_t *> nodes;
64 while (!nodes.empty())
70 nodes.push(node->
left());
71 nodes.push(node->
right());
85 if (leaf == par->
left())
102 SG_ERROR(
"Expected StreamingDenseFeatures\n");
104 SG_ERROR(
"Expected float32_t features\n");
110 SG_ERROR(
"No data features provided\n");
113 m_machines->reset_array();
119 m_feats->start_parser();
120 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
122 while (m_feats->get_next_example())
124 train_example(m_feats->get_vector(),
static_cast<int32_t
>(m_feats->get_label()));
125 m_feats->release_example();
128 if (ipass < m_num_passes-1)
129 m_feats->reset_stream();
131 m_feats->end_parser();
133 for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
148 printf(
"Empty Tree\n");
156 m_root->data.label = label;
157 m_leaves.insert(make_pair(label, m_root));
158 m_root->machine(create_machine(ex));
162 if (m_leaves.find(label) != m_leaves.end())
164 train_path(ex, m_leaves[label]);
169 while (node->
left() != NULL)
172 bool is_left = which_subtree(node, ex);
178 train_node(ex, node_label, node);
183 node = node->
right();
195 m_machines->push_back(mch);
196 left_node->
machine(m_machines->get_num_elements()-1);
197 m_leaves.insert(make_pair(left_node->
data.
label, left_node));
198 node->
left(left_node);
202 right_node->
machine(create_machine(ex));
203 m_leaves.insert(make_pair(label, right_node));
204 node->
right(right_node);
211 train_node(ex, node_label, node);
216 if (par->
left() == node)
221 train_node(ex, node_label, par);
250 m_machines->push_back(mch);
251 return m_machines->get_num_elements()-1;