VwConditionalProbabilityTree.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <vector>
00012 #include <stack>
00013 #include <shogun/multiclass/tree/VwConditionalProbabilityTree.h>
00014 
00015 using namespace shogun;
00016 using namespace std;
00017 
00018 CMulticlassLabels* CVwConditionalProbabilityTree::apply_multiclass(CFeatures* data)
00019 {
00020     if (data)
00021     {
00022         if (data->get_feature_class() != C_STREAMING_VW)
00023             SG_ERROR("Expected StreamingVwFeatures\n");
00024         set_features(dynamic_cast<CStreamingVwFeatures*>(data));
00025     }
00026 
00027     vector<int32_t> predicts;
00028 
00029     m_feats->start_parser();
00030     while (m_feats->get_next_example())
00031     {
00032         predicts.push_back(apply_multiclass_example(m_feats->get_example()));
00033         m_feats->release_example();
00034     }
00035     m_feats->end_parser();
00036 
00037     CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
00038     for (size_t i=0; i < predicts.size(); ++i)
00039         labels->set_int_label(i, predicts[i]);
00040     return labels;
00041 }
00042 
00043 int32_t CVwConditionalProbabilityTree::apply_multiclass_example(VwExample* ex)
00044 {
00045     ex->ld->label = FLT_MAX; // this will disable VW learning from this example
00046 
00047     compute_conditional_probabilities(ex);
00048     SGVector<float64_t> probs(m_leaves.size());
00049     for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
00050     {
00051         probs[it->first] = accumulate_conditional_probability(it->second);
00052     }
00053     return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
00054 }
00055 
00056 void CVwConditionalProbabilityTree::compute_conditional_probabilities(VwExample *ex)
00057 {
00058     stack<node_t *> nodes;
00059     nodes.push(m_root);
00060 
00061     while (!nodes.empty())
00062     {
00063         node_t *node = nodes.top();
00064         nodes.pop();
00065         if (node->left())
00066         {
00067             nodes.push(node->left());
00068             nodes.push(node->right());
00069 
00070             // don't calculate for leaf
00071             node->data.p_right = train_node(ex, node);
00072         }
00073     }
00074 }
00075 
00076 float64_t CVwConditionalProbabilityTree::accumulate_conditional_probability(node_t *leaf)
00077 {
00078     float64_t prob = 1;
00079     node_t *par = leaf->parent();
00080     while (par != NULL)
00081     {
00082         if (leaf == par->left())
00083             prob *= (1-par->data.p_right);
00084         else
00085             prob *= par->data.p_right;
00086 
00087         leaf = par;
00088         par = leaf->parent();
00089     }
00090 
00091     return prob;
00092 }
00093 
00094 bool CVwConditionalProbabilityTree::train_machine(CFeatures* data)
00095 {
00096     if (data)
00097     {
00098         if (data->get_feature_class() != C_STREAMING_VW)
00099             SG_ERROR("Expected StreamingVwFeatures\n");
00100         set_features(dynamic_cast<CStreamingVwFeatures*>(data));
00101     }
00102     else
00103     {
00104         if (!m_feats)
00105             SG_ERROR("No data features provided\n");
00106     }
00107 
00108     m_machines->reset_array();
00109     SG_UNREF(m_root);
00110     m_root = NULL;
00111 
00112     m_leaves.clear();
00113 
00114     m_feats->start_parser();
00115     for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
00116     {
00117         while (m_feats->get_next_example())
00118         {
00119             train_example(m_feats->get_example());
00120             m_feats->release_example();
00121         }
00122 
00123         if (ipass < m_num_passes-1)
00124             m_feats->reset_stream();
00125     }
00126     m_feats->end_parser();
00127 
00128     return true;
00129 }
00130 
00131 void CVwConditionalProbabilityTree::train_example(VwExample *ex)
00132 {
00133     int32_t label = static_cast<int32_t>(ex->ld->label);
00134 
00135     if (m_root == NULL)
00136     {
00137         m_root = new node_t();
00138         m_root->data.label = label;
00139         printf("  insert %d %p\n", label, m_root);
00140         m_leaves.insert(make_pair(label, m_root));
00141         m_root->machine(create_machine(ex));
00142         return;
00143     }
00144 
00145     if (m_leaves.find(label) != m_leaves.end())
00146     {
00147         train_path(ex, m_leaves[label]);
00148     }
00149     else
00150     {
00151         node_t *node = m_root;
00152         while (node->left() != NULL)
00153         {
00154             // not a leaf
00155             bool is_left = which_subtree(node, ex);
00156             if (is_left)
00157                 ex->ld->label = 0;
00158             else
00159                 ex->ld->label = 1;
00160             train_node(ex, node);
00161 
00162             if (is_left)
00163                 node = node->left();
00164             else
00165                 node = node->right();
00166         }
00167 
00168         printf("  remove %d %p\n", node->data.label, m_leaves[node->data.label]);
00169         m_leaves.erase(node->data.label);
00170 
00171         node_t *left_node = new node_t();
00172         left_node->data.label = node->data.label;
00173         node->data.label = -1;
00174         CVowpalWabbit *node_vw = dynamic_cast<CVowpalWabbit *>(m_machines->get_element(node->machine()));
00175         CVowpalWabbit *vw = new CVowpalWabbit(node_vw);
00176         SG_UNREF(node_vw);
00177         vw->set_learner();
00178         m_machines->push_back(vw);
00179         left_node->machine(m_machines->get_num_elements()-1);
00180         printf("  insert %d %p\n", left_node->data.label, left_node);
00181         m_leaves.insert(make_pair(left_node->data.label, left_node));
00182         node->left(left_node);
00183 
00184         node_t *right_node = new node_t();
00185         right_node->data.label = label;
00186         right_node->machine(create_machine(ex));
00187         printf("  insert %d %p\n", label, right_node);
00188         m_leaves.insert(make_pair(label, right_node));
00189         node->right(right_node);
00190     }
00191 }
00192 
00193 void CVwConditionalProbabilityTree::train_path(VwExample *ex, node_t *node)
00194 {
00195     ex->ld->label = 0;
00196     train_node(ex, node);
00197 
00198     node_t *par = node->parent();
00199     while (par != NULL)
00200     {
00201         if (par->left() == node)
00202             ex->ld->label = 0;
00203         else
00204             ex->ld->label = 1;
00205 
00206         train_node(ex, par);
00207         node = par;
00208         par = node->parent();
00209     }
00210 }
00211 
00212 float64_t CVwConditionalProbabilityTree::train_node(VwExample *ex, node_t *node)
00213 {
00214     CVowpalWabbit *vw = dynamic_cast<CVowpalWabbit*>(m_machines->get_element(node->machine()));
00215     ASSERT(vw);
00216     float64_t pred = vw->predict_and_finalize(ex);
00217     if (ex->ld->label != FLT_MAX)
00218         vw->get_learner()->train(ex, ex->eta_round);
00219     SG_UNREF(vw);
00220     return pred;
00221 }
00222 
00223 int32_t CVwConditionalProbabilityTree::create_machine(VwExample *ex)
00224 {
00225     CVowpalWabbit *vw = new CVowpalWabbit(m_feats);
00226     vw->set_learner();
00227     ex->ld->label = 0;
00228     vw->predict_and_finalize(ex);
00229     m_machines->push_back(vw);
00230     return m_machines->get_num_elements()-1;
00231 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation