Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <vector>
00012 #include <stack>
00013 #include <shogun/multiclass/tree/VwConditionalProbabilityTree.h>
00014
00015 using namespace shogun;
00016 using namespace std;
00017
00018 CMulticlassLabels* CVwConditionalProbabilityTree::apply_multiclass(CFeatures* data)
00019 {
00020 if (data)
00021 {
00022 if (data->get_feature_class() != C_STREAMING_VW)
00023 SG_ERROR("Expected StreamingVwFeatures\n");
00024 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
00025 }
00026
00027 vector<int32_t> predicts;
00028
00029 m_feats->start_parser();
00030 while (m_feats->get_next_example())
00031 {
00032 predicts.push_back(apply_multiclass_example(m_feats->get_example()));
00033 m_feats->release_example();
00034 }
00035 m_feats->end_parser();
00036
00037 CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
00038 for (size_t i=0; i < predicts.size(); ++i)
00039 labels->set_int_label(i, predicts[i]);
00040 return labels;
00041 }
00042
00043 int32_t CVwConditionalProbabilityTree::apply_multiclass_example(VwExample* ex)
00044 {
00045 ex->ld->label = FLT_MAX;
00046
00047 compute_conditional_probabilities(ex);
00048 SGVector<float64_t> probs(m_leaves.size());
00049 for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
00050 {
00051 probs[it->first] = accumulate_conditional_probability(it->second);
00052 }
00053 return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
00054 }
00055
00056 void CVwConditionalProbabilityTree::compute_conditional_probabilities(VwExample *ex)
00057 {
00058 stack<node_t *> nodes;
00059 nodes.push(m_root);
00060
00061 while (!nodes.empty())
00062 {
00063 node_t *node = nodes.top();
00064 nodes.pop();
00065 if (node->left())
00066 {
00067 nodes.push(node->left());
00068 nodes.push(node->right());
00069
00070
00071 node->data.p_right = train_node(ex, node);
00072 }
00073 }
00074 }
00075
00076 float64_t CVwConditionalProbabilityTree::accumulate_conditional_probability(node_t *leaf)
00077 {
00078 float64_t prob = 1;
00079 node_t *par = leaf->parent();
00080 while (par != NULL)
00081 {
00082 if (leaf == par->left())
00083 prob *= (1-par->data.p_right);
00084 else
00085 prob *= par->data.p_right;
00086
00087 leaf = par;
00088 par = leaf->parent();
00089 }
00090
00091 return prob;
00092 }
00093
00094 bool CVwConditionalProbabilityTree::train_machine(CFeatures* data)
00095 {
00096 if (data)
00097 {
00098 if (data->get_feature_class() != C_STREAMING_VW)
00099 SG_ERROR("Expected StreamingVwFeatures\n");
00100 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
00101 }
00102 else
00103 {
00104 if (!m_feats)
00105 SG_ERROR("No data features provided\n");
00106 }
00107
00108 m_machines->reset_array();
00109 SG_UNREF(m_root);
00110 m_root = NULL;
00111
00112 m_leaves.clear();
00113
00114 m_feats->start_parser();
00115 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
00116 {
00117 while (m_feats->get_next_example())
00118 {
00119 train_example(m_feats->get_example());
00120 m_feats->release_example();
00121 }
00122
00123 if (ipass < m_num_passes-1)
00124 m_feats->reset_stream();
00125 }
00126 m_feats->end_parser();
00127
00128 return true;
00129 }
00130
00131 void CVwConditionalProbabilityTree::train_example(VwExample *ex)
00132 {
00133 int32_t label = static_cast<int32_t>(ex->ld->label);
00134
00135 if (m_root == NULL)
00136 {
00137 m_root = new node_t();
00138 m_root->data.label = label;
00139 printf(" insert %d %p\n", label, m_root);
00140 m_leaves.insert(make_pair(label, m_root));
00141 m_root->machine(create_machine(ex));
00142 return;
00143 }
00144
00145 if (m_leaves.find(label) != m_leaves.end())
00146 {
00147 train_path(ex, m_leaves[label]);
00148 }
00149 else
00150 {
00151 node_t *node = m_root;
00152 while (node->left() != NULL)
00153 {
00154
00155 bool is_left = which_subtree(node, ex);
00156 if (is_left)
00157 ex->ld->label = 0;
00158 else
00159 ex->ld->label = 1;
00160 train_node(ex, node);
00161
00162 if (is_left)
00163 node = node->left();
00164 else
00165 node = node->right();
00166 }
00167
00168 printf(" remove %d %p\n", node->data.label, m_leaves[node->data.label]);
00169 m_leaves.erase(node->data.label);
00170
00171 node_t *left_node = new node_t();
00172 left_node->data.label = node->data.label;
00173 node->data.label = -1;
00174 CVowpalWabbit *node_vw = dynamic_cast<CVowpalWabbit *>(m_machines->get_element(node->machine()));
00175 CVowpalWabbit *vw = new CVowpalWabbit(node_vw);
00176 SG_UNREF(node_vw);
00177 vw->set_learner();
00178 m_machines->push_back(vw);
00179 left_node->machine(m_machines->get_num_elements()-1);
00180 printf(" insert %d %p\n", left_node->data.label, left_node);
00181 m_leaves.insert(make_pair(left_node->data.label, left_node));
00182 node->left(left_node);
00183
00184 node_t *right_node = new node_t();
00185 right_node->data.label = label;
00186 right_node->machine(create_machine(ex));
00187 printf(" insert %d %p\n", label, right_node);
00188 m_leaves.insert(make_pair(label, right_node));
00189 node->right(right_node);
00190 }
00191 }
00192
00193 void CVwConditionalProbabilityTree::train_path(VwExample *ex, node_t *node)
00194 {
00195 ex->ld->label = 0;
00196 train_node(ex, node);
00197
00198 node_t *par = node->parent();
00199 while (par != NULL)
00200 {
00201 if (par->left() == node)
00202 ex->ld->label = 0;
00203 else
00204 ex->ld->label = 1;
00205
00206 train_node(ex, par);
00207 node = par;
00208 par = node->parent();
00209 }
00210 }
00211
00212 float64_t CVwConditionalProbabilityTree::train_node(VwExample *ex, node_t *node)
00213 {
00214 CVowpalWabbit *vw = dynamic_cast<CVowpalWabbit*>(m_machines->get_element(node->machine()));
00215 ASSERT(vw);
00216 float64_t pred = vw->predict_and_finalize(ex);
00217 if (ex->ld->label != FLT_MAX)
00218 vw->get_learner()->train(ex, ex->eta_round);
00219 SG_UNREF(vw);
00220 return pred;
00221 }
00222
00223 int32_t CVwConditionalProbabilityTree::create_machine(VwExample *ex)
00224 {
00225 CVowpalWabbit *vw = new CVowpalWabbit(m_feats);
00226 vw->set_learner();
00227 ex->ld->label = 0;
00228 vw->predict_and_finalize(ex);
00229 m_machines->push_back(vw);
00230 return m_machines->get_num_elements()-1;
00231 }