00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #ifndef CONDITIONALPROBABILITYTREE_H__ 00012 #define CONDITIONALPROBABILITYTREE_H__ 00013 00014 #include <map> 00015 00016 #include <shogun/features/streaming/StreamingDenseFeatures.h> 00017 #include <shogun/multiclass/tree/TreeMachine.h> 00018 #include <shogun/multiclass/tree/ConditionalProbabilityTreeNodeData.h> 00019 00020 namespace shogun 00021 { 00022 00031 class CConditionalProbabilityTree: public CTreeMachine<ConditionalProbabilityTreeNodeData> 00032 { 00033 public: 00035 CConditionalProbabilityTree(int32_t num_passes=1) 00036 :m_num_passes(num_passes), m_feats(NULL) 00037 { 00038 } 00039 00041 virtual ~CConditionalProbabilityTree() { SG_UNREF(m_feats); } 00042 00044 virtual const char* get_name() const { return "ConditionalProbabilityTree"; } 00045 00047 void set_num_passes(int32_t num_passes) 00048 { 00049 m_num_passes = num_passes; 00050 } 00051 00053 int32_t get_num_passes() const 00054 { 00055 return m_num_passes; 00056 } 00057 00061 void set_features(CStreamingDenseFeatures<float32_t> *feats) 00062 { 00063 SG_REF(feats); 00064 SG_UNREF(m_feats); 00065 m_feats = feats; 00066 } 00067 00069 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00070 00074 virtual int32_t apply_multiclass_example(SGVector<float32_t> ex); 00075 00077 void print_tree(); 00078 protected: 00080 virtual bool train_require_labels() const { return false; } 00081 00088 virtual bool train_machine(CFeatures* data); 00089 00094 void train_example(SGVector<float32_t> ex, int32_t label); 00095 00100 void train_path(SGVector<float32_t> ex, node_t *node); 00101 00107 void train_node(SGVector<float32_t> ex, float64_t label, node_t *node); 00108 00113 float64_t predict_node(SGVector<float32_t> ex, node_t *node); 00114 00118 int32_t create_machine(SGVector<float32_t> ex); 00119 00125 virtual bool which_subtree(node_t *node, SGVector<float32_t> ex)=0; 00126 00128 void compute_conditional_probabilities(SGVector<float32_t> ex); 00129 00133 float64_t accumulate_conditional_probability(node_t *leaf); 00134 00135 int32_t m_num_passes; 00136 std::map<int32_t, node_t*> m_leaves; 00137 CStreamingDenseFeatures<float32_t> *m_feats; 00138 }; 00139 00140 } /* shogun */ 00141 00142 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */ 00143