ConditionalProbabilityTree.cpp

/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Chiyuan Zhang
 * Copyright (C) 2012 Chiyuan Zhang
 */

#include <vector>
#include <stack>
#include <map>

#include <shogun/multiclass/tree/ConditionalProbabilityTree.h>
#include <shogun/classifier/svm/OnlineLibLinear.h>

using namespace shogun;
using namespace std;

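// Classify a stream of examples: parse each vector from the streaming
// features, predict its class, and collect the predictions into a
// CMulticlassLabels object.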
CMulticlassLabels* CConditionalProbabilityTree::apply_multiclass(CFeatures* data)
{
    if (data)
    {
        if (data->get_feature_class() != C_STREAMING_DENSE)
            SG_ERROR("Expected StreamingDenseFeatures\n");
        if (data->get_feature_type() != F_SHORTREAL)
            SG_ERROR("Expected float32_t feature type\n");

        set_features(dynamic_cast<CStreamingDenseFeatures<float32_t>* >(data));
    }

    vector<int32_t> predicts;

    m_feats->start_parser();
    while (m_feats->get_next_example())
    {
        predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
        m_feats->release_example();
    }
    m_feats->end_parser();

    CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
    for (size_t i=0; i < predicts.size(); ++i)
        labels->set_int_label(i, predicts[i]);
    return labels;
}

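// Predict a single example: compute p_right at every internal node, then
// accumulate the path probability of each leaf and return the label with the
// highest probability. Leaf labels are used directly as indices into the
// probability vector, so they are assumed to lie in [0, m_leaves.size()).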
int32_t CConditionalProbabilityTree::apply_multiclass_example(SGVector<float32_t> ex)
{
    compute_conditional_probabilities(ex);
    SGVector<float64_t> probs(m_leaves.size());
    for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
    {
        probs[it->first] = accumulate_conditional_probability(it->second);
    }
    return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
}

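// Walk the tree iteratively (depth-first, with an explicit stack) and store in
// each internal node the probability of branching to its right child.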
void CConditionalProbabilityTree::compute_conditional_probabilities(SGVector<float32_t> ex)
{
    stack<node_t *> nodes;
    nodes.push(m_root);

    while (!nodes.empty())
    {
        node_t *node = nodes.top();
        nodes.pop();
        if (node->left())
        {
            nodes.push(node->left());
            nodes.push(node->right());

            // p_right is only needed for internal nodes; leaves have no children to branch to
            node->data.p_right = predict_node(ex, node);
        }
    }
}

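// Multiply the branching probabilities along the path from the root to the
// given leaf: p_right for every right turn, (1 - p_right) for every left turn.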
float64_t CConditionalProbabilityTree::accumulate_conditional_probability(node_t *leaf)
{
    float64_t prob = 1;
    node_t *par = leaf->parent();
    while (par != NULL)
    {
        if (leaf == par->left())
            prob *= (1-par->data.p_right);
        else
            prob *= par->data.p_right;

        leaf = par;
        par = leaf->parent();
    }

    return prob;
}

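// Online training: stream the examples for m_num_passes passes, growing the
// tree and updating the per-node OnlineLibLinear machines one example at a
// time, then finalize every machine with stop_train().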
bool CConditionalProbabilityTree::train_machine(CFeatures* data)
{
    if (data)
    {
        if (data->get_feature_class() != C_STREAMING_DENSE)
            SG_ERROR("Expected StreamingDenseFeatures\n");
        if (data->get_feature_type() != F_SHORTREAL)
            SG_ERROR("Expected float32_t features\n");
        set_features(dynamic_cast<CStreamingDenseFeatures<float32_t> *>(data));
    }
    else
    {
        if (!m_feats)
            SG_ERROR("No data features provided\n");
    }

    m_machines->reset_array();
    SG_UNREF(m_root);
    m_root = NULL;

    m_leaves.clear();

    m_feats->start_parser();
    for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
    {
        while (m_feats->get_next_example())
        {
            train_example(m_feats->get_vector(), static_cast<int32_t>(m_feats->get_label()));
            m_feats->release_example();
        }

        if (ipass < m_num_passes-1)
            m_feats->reset_stream();
    }
    m_feats->end_parser();

    for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
    {
        COnlineLibLinear *lll = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(i));
        ASSERT(lll);
        lll->stop_train();
        SG_UNREF(lll);
    }

    return true;
}

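// Debug helper: dump the tree structure, or report that the tree is empty.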
void CConditionalProbabilityTree::print_tree()
{
    if (m_root)
        m_root->debug_print(ConditionalProbabilityTreeNodeData::print_data);
    else
        printf("Empty Tree\n");
}

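// Process one labelled example. If the tree is empty, the example becomes the
// root/leaf for its label. If the label already has a leaf, train the binary
// classifiers along the path to that leaf. Otherwise descend to a leaf chosen
// by which_subtree(), split that leaf (its old label moves to a new left
// child that copies its machine) and attach a new right leaf for the unseen label.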
void CConditionalProbabilityTree::train_example(SGVector<float32_t> ex, int32_t label)
{
    if (m_root == NULL)
    {
        m_root = new node_t();
        m_root->data.label = label;
        m_leaves.insert(make_pair(label, m_root));
        m_root->machine(create_machine(ex));
        return;
    }

    if (m_leaves.find(label) != m_leaves.end())
    {
        train_path(ex, m_leaves[label]);
    }
    else
    {
        node_t *node = m_root;
        while (node->left() != NULL)
        {
            // internal node: decide which subtree to descend into and train
            // the node's binary classifier on that decision
            bool is_left = which_subtree(node, ex);
            float64_t node_label = is_left ? 0 : 1;
            train_node(ex, node_label, node);

            if (is_left)
                node = node->left();
            else
                node = node->right();
        }

        m_leaves.erase(node->data.label);

        node_t *left_node = new node_t();
        left_node->data.label = node->data.label;
        node->data.label = -1;
        COnlineLibLinear *node_mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
        COnlineLibLinear *mch = new COnlineLibLinear(node_mch);
        SG_UNREF(node_mch);
        mch->start_train();
        m_machines->push_back(mch);
        left_node->machine(m_machines->get_num_elements()-1);
        m_leaves.insert(make_pair(left_node->data.label, left_node));
        node->left(left_node);

        node_t *right_node = new node_t();
        right_node->data.label = label;
        right_node->machine(create_machine(ex));
        m_leaves.insert(make_pair(label, right_node));
        node->right(right_node);
    }
}

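// Reinforce the path to an existing leaf: train the leaf's machine with
// label 0, then train every ancestor with the left (0) / right (1) decision
// that leads towards this leaf.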
void CConditionalProbabilityTree::train_path(SGVector<float32_t> ex, node_t *node)
{
    float64_t node_label = 0;
    train_node(ex, node_label, node);

    node_t *par = node->parent();
    while (par != NULL)
    {
        if (par->left() == node)
            node_label = 0;
        else
            node_label = 1;

        train_node(ex, node_label, par);
        node = par;
        par = node->parent();
    }
}

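// Single online update of the binary classifier attached to a node.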
void CConditionalProbabilityTree::train_node(SGVector<float32_t> ex, float64_t label, node_t *node)
{
    COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
    ASSERT(mch);
    mch->train_one(ex, label);
    SG_UNREF(mch);
}

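// Apply a node's binary classifier to the example and squash the raw decision
// value through the logistic function to obtain the probability of branching right.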
float64_t CConditionalProbabilityTree::predict_node(SGVector<float32_t> ex, node_t *node)
{
    COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
    ASSERT(mch);
    float64_t pred = mch->apply_one(ex.vector, ex.vlen);
    SG_UNREF(mch);
    // use the sigmoid function to turn the decision value into a valid probability
    return 1.0/(1+CMath::exp(-pred));
}

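// Create a fresh OnlineLibLinear machine for a new node, seed it with one
// update on the current example (label 0), and return its index in m_machines.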
int32_t CConditionalProbabilityTree::create_machine(SGVector<float32_t> ex)
{
    COnlineLibLinear *mch = new COnlineLibLinear();
    mch->start_train();
    mch->train_one(ex, 0);
    m_machines->push_back(mch);
    return m_machines->get_num_elements()-1;
}