SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
VwConditionalProbabilityTree.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #include <vector>
12 #include <stack>
15 
16 using namespace shogun;
17 using namespace std;
18 
20 {
21  if (data)
22  {
23  if (data->get_feature_class() != C_STREAMING_VW)
24  SG_ERROR("Expected StreamingVwFeatures\n")
25  set_features(dynamic_cast<CStreamingVwFeatures*>(data));
26  }
27 
28  vector<int32_t> predicts;
29 
30  m_feats->start_parser();
31  while (m_feats->get_next_example())
32  {
33  predicts.push_back(apply_multiclass_example(m_feats->get_example()));
34  m_feats->release_example();
35  }
36  m_feats->end_parser();
37 
38  CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
39  for (size_t i=0; i < predicts.size(); ++i)
40  labels->set_int_label(i, predicts[i]);
41  return labels;
42 }
43 
45 {
46  ex->ld->label = FLT_MAX; // this will disable VW learning from this example
47 
48  compute_conditional_probabilities(ex);
49  SGVector<float64_t> probs(m_leaves.size());
50  for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
51  {
52  probs[it->first] = accumulate_conditional_probability(it->second);
53  }
54  return CMath::arg_max(probs.vector, 1, probs.vlen);
55 }
56 
58 {
59  stack<bnode_t *> nodes;
60  nodes.push((bnode_t*) m_root);
61 
62  while (!nodes.empty())
63  {
64  bnode_t *node = nodes.top();
65  nodes.pop();
66  if (node->left())
67  {
68  nodes.push(node->left());
69  nodes.push(node->right());
70 
71  // don't calculate for leaf
72  node->data.p_right = train_node(ex, node);
73  }
74  }
75 }
76 
78 {
79  float64_t prob = 1;
80  bnode_t *par = (bnode_t*) leaf->parent();
81  while (par != NULL)
82  {
83  if (leaf == par->left())
84  prob *= (1-par->data.p_right);
85  else
86  prob *= par->data.p_right;
87 
88  leaf = par;
89  par = (bnode_t*) leaf->parent();
90  }
91 
92  return prob;
93 }
94 
96 {
97  if (data)
98  {
99  if (data->get_feature_class() != C_STREAMING_VW)
100  SG_ERROR("Expected StreamingVwFeatures\n")
101  set_features(dynamic_cast<CStreamingVwFeatures*>(data));
102  }
103  else
104  {
105  if (!m_feats)
106  SG_ERROR("No data features provided\n")
107  }
108 
109  m_machines->reset_array();
110  SG_UNREF(m_root);
111  m_root = NULL;
112 
113  m_leaves.clear();
114 
115  m_feats->start_parser();
116  for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
117  {
118  while (m_feats->get_next_example())
119  {
120  train_example(m_feats->get_example());
121  m_feats->release_example();
122  }
123 
124  if (ipass < m_num_passes-1)
125  m_feats->reset_stream();
126  }
127  m_feats->end_parser();
128 
129  return true;
130 }
131 
133 {
134  int32_t label = static_cast<int32_t>(ex->ld->label);
135 
136  if (m_root == NULL)
137  {
138  m_root = new bnode_t();
139  m_root->data.label = label;
140  printf(" insert %d %p\n", label, m_root);
141  m_leaves.insert(make_pair(label,(bnode_t*) m_root));
142  m_root->machine(create_machine(ex));
143  return;
144  }
145 
146  if (m_leaves.find(label) != m_leaves.end())
147  {
148  train_path(ex, m_leaves[label]);
149  }
150  else
151  {
152  bnode_t *node = (bnode_t*) m_root;
153  while (node->left() != NULL)
154  {
155  // not a leaf
156  bool is_left = which_subtree(node, ex);
157  if (is_left)
158  ex->ld->label = 0;
159  else
160  ex->ld->label = 1;
161  train_node(ex, node);
162 
163  if (is_left)
164  node = node->left();
165  else
166  node = node->right();
167  }
168 
169  printf(" remove %d %p\n", node->data.label, m_leaves[node->data.label]);
170  m_leaves.erase(node->data.label);
171 
172  bnode_t *left_node = new bnode_t();
173  left_node->data.label = node->data.label;
174  node->data.label = -1;
175  CVowpalWabbit *node_vw = dynamic_cast<CVowpalWabbit *>(m_machines->get_element(node->machine()));
176  CVowpalWabbit *vw = new CVowpalWabbit(node_vw);
177  SG_UNREF(node_vw);
178  vw->set_learner();
179  m_machines->push_back(vw);
180  left_node->machine(m_machines->get_num_elements()-1);
181  printf(" insert %d %p\n", left_node->data.label, left_node);
182  m_leaves.insert(make_pair(left_node->data.label, left_node));
183  node->left(left_node);
184 
185  bnode_t *right_node = new bnode_t();
186  right_node->data.label = label;
187  right_node->machine(create_machine(ex));
188  printf(" insert %d %p\n", label, right_node);
189  m_leaves.insert(make_pair(label, right_node));
190  node->right(right_node);
191  }
192 }
193 
195 {
196  ex->ld->label = 0;
197  train_node(ex, node);
198 
199  bnode_t *par = (bnode_t*) node->parent();
200  while (par != NULL)
201  {
202  if (par->left() == node)
203  ex->ld->label = 0;
204  else
205  ex->ld->label = 1;
206 
207  train_node(ex, par);
208  node = par;
209  par = (bnode_t*) node->parent();
210  }
211 }
212 
214 {
215  CVowpalWabbit *vw = dynamic_cast<CVowpalWabbit*>(m_machines->get_element(node->machine()));
216  ASSERT(vw)
217  float64_t pred = vw->predict_and_finalize(ex);
218  if (ex->ld->label != FLT_MAX)
219  vw->get_learner()->train(ex, ex->eta_round);
220  SG_UNREF(vw);
221  return pred;
222 }
223 
225 {
226  CVowpalWabbit *vw = new CVowpalWabbit(m_feats);
227  vw->set_learner();
228  ex->ld->label = 0;
229  vw->predict_and_finalize(ex);
230  m_machines->push_back(vw);
231  return m_machines->get_num_elements()-1;
232 }

SHOGUN Machine Learning Toolbox - Documentation