SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
VwConditionalProbabilityTree.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #include <vector>
12 #include <stack>
14 
15 using namespace shogun;
16 using namespace std;
17 
19 {
20  if (data)
21  {
22  if (data->get_feature_class() != C_STREAMING_VW)
23  SG_ERROR("Expected StreamingVwFeatures\n");
24  set_features(dynamic_cast<CStreamingVwFeatures*>(data));
25  }
26 
27  vector<int32_t> predicts;
28 
29  m_feats->start_parser();
30  while (m_feats->get_next_example())
31  {
32  predicts.push_back(apply_multiclass_example(m_feats->get_example()));
33  m_feats->release_example();
34  }
35  m_feats->end_parser();
36 
37  CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
38  for (size_t i=0; i < predicts.size(); ++i)
39  labels->set_int_label(i, predicts[i]);
40  return labels;
41 }
42 
44 {
45  ex->ld->label = FLT_MAX; // this will disable VW learning from this example
46 
47  compute_conditional_probabilities(ex);
48  SGVector<float64_t> probs(m_leaves.size());
49  for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
50  {
51  probs[it->first] = accumulate_conditional_probability(it->second);
52  }
53  return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
54 }
55 
57 {
58  stack<node_t *> nodes;
59  nodes.push(m_root);
60 
61  while (!nodes.empty())
62  {
63  node_t *node = nodes.top();
64  nodes.pop();
65  if (node->left())
66  {
67  nodes.push(node->left());
68  nodes.push(node->right());
69 
70  // don't calculate for leaf
71  node->data.p_right = train_node(ex, node);
72  }
73  }
74 }
75 
77 {
78  float64_t prob = 1;
79  node_t *par = leaf->parent();
80  while (par != NULL)
81  {
82  if (leaf == par->left())
83  prob *= (1-par->data.p_right);
84  else
85  prob *= par->data.p_right;
86 
87  leaf = par;
88  par = leaf->parent();
89  }
90 
91  return prob;
92 }
93 
95 {
96  if (data)
97  {
98  if (data->get_feature_class() != C_STREAMING_VW)
99  SG_ERROR("Expected StreamingVwFeatures\n");
100  set_features(dynamic_cast<CStreamingVwFeatures*>(data));
101  }
102  else
103  {
104  if (!m_feats)
105  SG_ERROR("No data features provided\n");
106  }
107 
108  m_machines->reset_array();
109  SG_UNREF(m_root);
110  m_root = NULL;
111 
112  m_leaves.clear();
113 
114  m_feats->start_parser();
115  for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
116  {
117  while (m_feats->get_next_example())
118  {
119  train_example(m_feats->get_example());
120  m_feats->release_example();
121  }
122 
123  if (ipass < m_num_passes-1)
124  m_feats->reset_stream();
125  }
126  m_feats->end_parser();
127 
128  return true;
129 }
130 
132 {
133  int32_t label = static_cast<int32_t>(ex->ld->label);
134 
135  if (m_root == NULL)
136  {
137  m_root = new node_t();
138  m_root->data.label = label;
139  printf(" insert %d %p\n", label, m_root);
140  m_leaves.insert(make_pair(label, m_root));
141  m_root->machine(create_machine(ex));
142  return;
143  }
144 
145  if (m_leaves.find(label) != m_leaves.end())
146  {
147  train_path(ex, m_leaves[label]);
148  }
149  else
150  {
151  node_t *node = m_root;
152  while (node->left() != NULL)
153  {
154  // not a leaf
155  bool is_left = which_subtree(node, ex);
156  if (is_left)
157  ex->ld->label = 0;
158  else
159  ex->ld->label = 1;
160  train_node(ex, node);
161 
162  if (is_left)
163  node = node->left();
164  else
165  node = node->right();
166  }
167 
168  printf(" remove %d %p\n", node->data.label, m_leaves[node->data.label]);
169  m_leaves.erase(node->data.label);
170 
171  node_t *left_node = new node_t();
172  left_node->data.label = node->data.label;
173  node->data.label = -1;
174  CVowpalWabbit *node_vw = dynamic_cast<CVowpalWabbit *>(m_machines->get_element(node->machine()));
175  CVowpalWabbit *vw = new CVowpalWabbit(node_vw);
176  SG_UNREF(node_vw);
177  vw->set_learner();
178  m_machines->push_back(vw);
179  left_node->machine(m_machines->get_num_elements()-1);
180  printf(" insert %d %p\n", left_node->data.label, left_node);
181  m_leaves.insert(make_pair(left_node->data.label, left_node));
182  node->left(left_node);
183 
184  node_t *right_node = new node_t();
185  right_node->data.label = label;
186  right_node->machine(create_machine(ex));
187  printf(" insert %d %p\n", label, right_node);
188  m_leaves.insert(make_pair(label, right_node));
189  node->right(right_node);
190  }
191 }
192 
194 {
195  ex->ld->label = 0;
196  train_node(ex, node);
197 
198  node_t *par = node->parent();
199  while (par != NULL)
200  {
201  if (par->left() == node)
202  ex->ld->label = 0;
203  else
204  ex->ld->label = 1;
205 
206  train_node(ex, par);
207  node = par;
208  par = node->parent();
209  }
210 }
211 
213 {
214  CVowpalWabbit *vw = dynamic_cast<CVowpalWabbit*>(m_machines->get_element(node->machine()));
215  ASSERT(vw);
216  float64_t pred = vw->predict_and_finalize(ex);
217  if (ex->ld->label != FLT_MAX)
218  vw->get_learner()->train(ex, ex->eta_round);
219  SG_UNREF(vw);
220  return pred;
221 }
222 
224 {
225  CVowpalWabbit *vw = new CVowpalWabbit(m_feats);
226  vw->set_learner();
227  ex->ld->label = 0;
228  vw->predict_and_finalize(ex);
229  m_machines->push_back(vw);
230  return m_machines->get_num_elements()-1;
231 }

SHOGUN Machine Learning Toolbox - Documentation