SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ConditionalProbabilityTree.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #include <vector>
12 #include <stack>
13 
17 
18 using namespace shogun;
19 using namespace std;
20 
22 {
23  if (data)
24  {
25  if (data->get_feature_class() != C_STREAMING_DENSE)
26  SG_ERROR("Expected StreamingDenseFeatures\n")
27  if (data->get_feature_type() != F_SHORTREAL)
28  SG_ERROR("Expected float32_t feature type\n")
29 
30  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t>* >(data));
31  }
32 
33  vector<int32_t> predicts;
34 
35  m_feats->start_parser();
36  while (m_feats->get_next_example())
37  {
38  predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
39  m_feats->release_example();
40  }
41  m_feats->end_parser();
42 
43  CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
44  for (size_t i=0; i < predicts.size(); ++i)
45  labels->set_int_label(i, predicts[i]);
46  return labels;
47 }
48 
50 {
51  compute_conditional_probabilities(ex);
52  SGVector<float64_t> probs(m_leaves.size());
53  for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
54  {
55  probs[it->first] = accumulate_conditional_probability(it->second);
56  }
57  return CMath::arg_max(probs.vector, 1, probs.vlen);
58 }
59 
61 {
62  stack<bnode_t *> nodes;
63  nodes.push((bnode_t*) m_root);
64 
65  while (!nodes.empty())
66  {
67  bnode_t *node = nodes.top();
68  nodes.pop();
69  if (node->left())
70  {
71  nodes.push(node->left());
72  nodes.push(node->right());
73 
74  // don't calculate for leaf
75  node->data.p_right = predict_node(ex, node);
76  }
77  }
78 }
79 
81 {
82  float64_t prob = 1;
83  bnode_t *par = (bnode_t*) leaf->parent();
84  while (par != NULL)
85  {
86  if (leaf == par->left())
87  prob *= (1-par->data.p_right);
88  else
89  prob *= par->data.p_right;
90 
91  leaf = par;
92  par = (bnode_t*) leaf->parent();
93  }
94 
95  return prob;
96 }
97 
99 {
100  if (data)
101  {
102  if (data->get_feature_class() != C_STREAMING_DENSE)
103  SG_ERROR("Expected StreamingDenseFeatures\n")
104  if (data->get_feature_type() != F_SHORTREAL)
105  SG_ERROR("Expected float32_t features\n")
106  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t> *>(data));
107  }
108  else
109  {
110  if (!m_feats)
111  SG_ERROR("No data features provided\n")
112  }
113 
114  m_machines->reset_array();
115  SG_UNREF(m_root);
116  m_root = NULL;
117 
118  m_leaves.clear();
119 
120  m_feats->start_parser();
121  for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
122  {
123  while (m_feats->get_next_example())
124  {
125  train_example(m_feats->get_vector(), static_cast<int32_t>(m_feats->get_label()));
126  m_feats->release_example();
127  }
128 
129  if (ipass < m_num_passes-1)
130  m_feats->reset_stream();
131  }
132  m_feats->end_parser();
133 
134  for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
135  {
136  COnlineLibLinear *lll = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(i));
137  lll->stop_train();
138  SG_UNREF(lll);
139  }
140 
141  return true;
142 }
143 
145 {
	// Prints the tree when one has been built (m_root non-NULL); otherwise
	// writes "Empty Tree" to stdout via printf.
	// NOTE(review): the statement executed when m_root is non-NULL (listing
	// line 147) is absent from this listing, so that branch is not visible here.
146  if (m_root)
148  else
149  printf("Empty Tree\n");
150 }
151 
153 {
154  if (m_root == NULL)
155  {
156  m_root = new bnode_t();
157  m_root->data.label = label;
158  m_leaves.insert(make_pair(label, (bnode_t*) m_root));
159  m_root->machine(create_machine(ex));
160  return;
161  }
162 
163  if (m_leaves.find(label) != m_leaves.end())
164  {
165  train_path(ex, m_leaves[label]);
166  }
167  else
168  {
169  bnode_t *node = (bnode_t*) m_root;
170  while (node->left() != NULL)
171  {
172  // not a leaf
173  bool is_left = which_subtree(node, ex);
174  float64_t node_label;
175  if (is_left)
176  node_label = 0;
177  else
178  node_label = 1;
179  train_node(ex, node_label, node);
180 
181  if (is_left)
182  node = node->left();
183  else
184  node = node->right();
185  }
186 
187  m_leaves.erase(node->data.label);
188 
189  bnode_t *left_node = new bnode_t();
190  left_node->data.label = node->data.label;
191  node->data.label = -1;
192  COnlineLibLinear *node_mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
193  COnlineLibLinear *mch = new COnlineLibLinear(node_mch);
194  SG_UNREF(node_mch);
195  mch->start_train();
196  m_machines->push_back(mch);
197  left_node->machine(m_machines->get_num_elements()-1);
198  m_leaves.insert(make_pair(left_node->data.label, left_node));
199  node->left(left_node);
200 
201  bnode_t *right_node = new bnode_t();
202  right_node->data.label = label;
203  right_node->machine(create_machine(ex));
204  m_leaves.insert(make_pair(label, right_node));
205  node->right(right_node);
206  }
207 }
208 
210 {
211  float64_t node_label = 0;
212  train_node(ex, node_label, node);
213 
214  bnode_t *par = (bnode_t*) node->parent();
215  while (par != NULL)
216  {
217  if (par->left() == node)
218  node_label = 0;
219  else
220  node_label = 1;
221 
222  train_node(ex, node_label, par);
223  node = par;
224  par = (bnode_t*) node->parent();
225  }
226 }
227 
229 {
230  REQUIRE(node, "Node must not be NULL\n");
231  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
232  REQUIRE(mch, "Instance of %s could not be casted to COnlineLibLinear\n", node->get_name());
233  mch->train_one(ex, label);
234  SG_UNREF(mch);
235 }
236 
238 {
239  REQUIRE(node, "Node must not be NULL\n");
240  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
241  REQUIRE(mch, "Instance of %s could not be casted to COnlineLibLinear\n", node->get_name());
242  float64_t pred = mch->apply_one(ex.vector, ex.vlen);
243  SG_UNREF(mch);
244  // use sigmoid function to turn the decision value into valid probability
245  return 1.0/(1+CMath::exp(-pred));
246 }
247 
249 {
250  COnlineLibLinear *mch = new COnlineLibLinear();
251  mch->start_train();
252  mch->train_one(ex, 0);
253  m_machines->push_back(mch);
254  return m_machines->get_num_elements()-1;
255 }

SHOGUN Machine Learning Toolbox - Documentation