SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ConditionalProbabilityTree.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #include <vector>
12 #include <stack>
13 
16 
17 using namespace shogun;
18 using namespace std;
19 
21 {
22  if (data)
23  {
24  if (data->get_feature_class() != C_STREAMING_DENSE)
25  SG_ERROR("Expected StreamingDenseFeatures\n")
26  if (data->get_feature_type() != F_SHORTREAL)
27  SG_ERROR("Expected float32_t feature type\n")
28 
29  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t>* >(data));
30  }
31 
32  vector<int32_t> predicts;
33 
34  m_feats->start_parser();
35  while (m_feats->get_next_example())
36  {
37  predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
38  m_feats->release_example();
39  }
40  m_feats->end_parser();
41 
42  CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
43  for (size_t i=0; i < predicts.size(); ++i)
44  labels->set_int_label(i, predicts[i]);
45  return labels;
46 }
47 
49 {
50  compute_conditional_probabilities(ex);
51  SGVector<float64_t> probs(m_leaves.size());
52  for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
53  {
54  probs[it->first] = accumulate_conditional_probability(it->second);
55  }
56  return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
57 }
58 
60 {
61  stack<node_t *> nodes;
62  nodes.push(m_root);
63 
64  while (!nodes.empty())
65  {
66  node_t *node = nodes.top();
67  nodes.pop();
68  if (node->left())
69  {
70  nodes.push(node->left());
71  nodes.push(node->right());
72 
73  // don't calculate for leaf
74  node->data.p_right = predict_node(ex, node);
75  }
76  }
77 }
78 
80 {
81  float64_t prob = 1;
82  node_t *par = leaf->parent();
83  while (par != NULL)
84  {
85  if (leaf == par->left())
86  prob *= (1-par->data.p_right);
87  else
88  prob *= par->data.p_right;
89 
90  leaf = par;
91  par = leaf->parent();
92  }
93 
94  return prob;
95 }
96 
98 {
99  if (data)
100  {
101  if (data->get_feature_class() != C_STREAMING_DENSE)
102  SG_ERROR("Expected StreamingDenseFeatures\n")
103  if (data->get_feature_type() != F_SHORTREAL)
104  SG_ERROR("Expected float32_t features\n")
105  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t> *>(data));
106  }
107  else
108  {
109  if (!m_feats)
110  SG_ERROR("No data features provided\n")
111  }
112 
113  m_machines->reset_array();
114  SG_UNREF(m_root);
115  m_root = NULL;
116 
117  m_leaves.clear();
118 
119  m_feats->start_parser();
120  for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
121  {
122  while (m_feats->get_next_example())
123  {
124  train_example(m_feats->get_vector(), static_cast<int32_t>(m_feats->get_label()));
125  m_feats->release_example();
126  }
127 
128  if (ipass < m_num_passes-1)
129  m_feats->reset_stream();
130  }
131  m_feats->end_parser();
132 
133  for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
134  {
135  COnlineLibLinear *lll = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(i));
136  lll->stop_train();
137  SG_UNREF(lll);
138  }
139 
140  return true;
141 }
142 
144 {
// Print the tree if one exists; when m_root is NULL, print "Empty Tree".
// NOTE(review): the statement belonging to the `if (m_root)` branch (listing
// line 146) is missing from this listing, so as shown the `if` has no body
// and the `else` below cannot parse. It presumably prints the tree rooted at
// m_root -- restore it from the original source (TODO confirm).
145  if (m_root)
147  else
148  printf("Empty Tree\n");
149 }
150 
152 {
153  if (m_root == NULL)
154  {
155  m_root = new node_t();
156  m_root->data.label = label;
157  m_leaves.insert(make_pair(label, m_root));
158  m_root->machine(create_machine(ex));
159  return;
160  }
161 
162  if (m_leaves.find(label) != m_leaves.end())
163  {
164  train_path(ex, m_leaves[label]);
165  }
166  else
167  {
168  node_t *node = m_root;
169  while (node->left() != NULL)
170  {
171  // not a leaf
172  bool is_left = which_subtree(node, ex);
173  float64_t node_label;
174  if (is_left)
175  node_label = 0;
176  else
177  node_label = 1;
178  train_node(ex, node_label, node);
179 
180  if (is_left)
181  node = node->left();
182  else
183  node = node->right();
184  }
185 
186  m_leaves.erase(node->data.label);
187 
188  node_t *left_node = new node_t();
189  left_node->data.label = node->data.label;
190  node->data.label = -1;
191  COnlineLibLinear *node_mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
192  COnlineLibLinear *mch = new COnlineLibLinear(node_mch);
193  SG_UNREF(node_mch);
194  mch->start_train();
195  m_machines->push_back(mch);
196  left_node->machine(m_machines->get_num_elements()-1);
197  m_leaves.insert(make_pair(left_node->data.label, left_node));
198  node->left(left_node);
199 
200  node_t *right_node = new node_t();
201  right_node->data.label = label;
202  right_node->machine(create_machine(ex));
203  m_leaves.insert(make_pair(label, right_node));
204  node->right(right_node);
205  }
206 }
207 
209 {
210  float64_t node_label = 0;
211  train_node(ex, node_label, node);
212 
213  node_t *par = node->parent();
214  while (par != NULL)
215  {
216  if (par->left() == node)
217  node_label = 0;
218  else
219  node_label = 1;
220 
221  train_node(ex, node_label, par);
222  node = par;
223  par = node->parent();
224  }
225 }
226 
228 {
229  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
230  ASSERT(mch)
231  mch->train_one(ex, label);
232  SG_UNREF(mch);
233 }
234 
236 {
237  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
238  ASSERT(mch)
239  float64_t pred = mch->apply_one(ex.vector, ex.vlen);
240  SG_UNREF(mch);
241  // use sigmoid function to turn the decision value into valid probability
242  return 1.0/(1+CMath::exp(-pred));
243 }
244 
246 {
247  COnlineLibLinear *mch = new COnlineLibLinear();
248  mch->start_train();
249  mch->train_one(ex, 0);
250  m_machines->push_back(mch);
251  return m_machines->get_num_elements()-1;
252 }

SHOGUN Machine Learning Toolbox - Documentation