ConditionalProbabilityTree.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #ifndef CONDITIONALPROBABILITYTREE_H__
00012 #define CONDITIONALPROBABILITYTREE_H__
00013 
00014 #include <map>
00015 
00016 #include <shogun/features/streaming/StreamingDenseFeatures.h>
00017 #include <shogun/multiclass/tree/TreeMachine.h>
00018 #include <shogun/multiclass/tree/ConditionalProbabilityTreeNodeData.h>
00019 
00020 namespace shogun
00021 {
00022 
00031 class CConditionalProbabilityTree: public CTreeMachine<ConditionalProbabilityTreeNodeData>
00032 {
00033 public:
00035     CConditionalProbabilityTree(int32_t num_passes=1)
00036         :m_num_passes(num_passes), m_feats(NULL)
00037     {
00038     }
00039 
00041     virtual ~CConditionalProbabilityTree() { SG_UNREF(m_feats); }
00042 
00044     virtual const char* get_name() const { return "ConditionalProbabilityTree"; }
00045 
00047     void set_num_passes(int32_t num_passes)
00048     {
00049         m_num_passes = num_passes;
00050     }
00051 
00053     int32_t get_num_passes() const
00054     {
00055         return m_num_passes;
00056     }
00057 
00061     void set_features(CStreamingDenseFeatures<float32_t> *feats)
00062     {
00063         SG_REF(feats);
00064         SG_UNREF(m_feats);
00065         m_feats = feats;
00066     }
00067 
00069     virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00070 
00074     virtual int32_t apply_multiclass_example(SGVector<float32_t> ex);
00075 
00077     void print_tree();
00078 protected:
00080     virtual bool train_require_labels() const { return false; }
00081 
00088     virtual bool train_machine(CFeatures* data);
00089 
00094     void train_example(SGVector<float32_t> ex, int32_t label);
00095 
00100     void train_path(SGVector<float32_t> ex, node_t *node);
00101 
00107     void train_node(SGVector<float32_t> ex, float64_t label, node_t *node);
00108 
00113     float64_t predict_node(SGVector<float32_t> ex, node_t *node);
00114 
00118     int32_t create_machine(SGVector<float32_t> ex);
00119 
00125     virtual bool which_subtree(node_t *node, SGVector<float32_t> ex)=0;
00126 
00128     void compute_conditional_probabilities(SGVector<float32_t> ex);
00129 
00133     float64_t accumulate_conditional_probability(node_t *leaf);
00134 
00135     int32_t m_num_passes; 
00136     std::map<int32_t, node_t*> m_leaves; 
00137     CStreamingDenseFeatures<float32_t> *m_feats; 
00138 };
00139 
00140 } /* shogun */ 
00141 
00142 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
00143 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation