00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #ifndef CONDITIONALPROBABILITYTREE_H__ 00012 #define CONDITIONALPROBABILITYTREE_H__ 00013 00014 #include <map> 00015 00016 #include <shogun/multiclass/tree/TreeMachine.h> 00017 #include <shogun/classifier/vw/VowpalWabbit.h> 00018 00019 namespace shogun 00020 { 00021 00023 struct VwConditionalProbabilityTreeNodeData 00024 { 00026 int32_t label; 00028 float64_t p_right; 00029 00031 VwConditionalProbabilityTreeNodeData():label(-1), p_right(0) {} 00032 }; 00033 00035 class CVwConditionalProbabilityTree: public CTreeMachine<VwConditionalProbabilityTreeNodeData> 00036 { 00037 public: 00038 typedef CTreeMachineNode<VwConditionalProbabilityTreeNodeData> node_t; 00039 00041 CVwConditionalProbabilityTree(int32_t num_passes=1) 00042 :m_num_passes(num_passes), m_feats(NULL) 00043 { 00044 } 00045 00047 virtual ~CVwConditionalProbabilityTree() {} 00048 00050 virtual const char* get_name() const { return "VwConditionalProbabilityTree"; } 00051 00053 void set_num_passes(int32_t num_passes) 00054 { 00055 m_num_passes = num_passes; 00056 } 00057 00059 int32_t get_num_passes() const 00060 { 00061 return m_num_passes; 00062 } 00063 00067 void set_features(CStreamingVwFeatures *feats) 00068 { 00069 SG_REF(feats); 00070 SG_UNREF(m_feats); 00071 m_feats = feats; 00072 } 00073 00075 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00076 00078 virtual int32_t apply_multiclass_example(VwExample* ex); 00079 protected: 00081 virtual bool train_require_labels() const { return false; } 00082 00089 virtual bool train_machine(CFeatures* data); 00090 00094 void train_example(VwExample *ex); 00095 00100 void train_path(VwExample *ex, node_t *node); 00101 00107 float64_t train_node(VwExample *ex, node_t *node); 00108 00112 int32_t create_machine(VwExample *ex); 00113 00119 virtual bool which_subtree(node_t *node, VwExample *ex)=0; 00120 00122 void compute_conditional_probabilities(VwExample *ex); 00123 00127 float64_t accumulate_conditional_probability(node_t *leaf); 00128 00129 int32_t m_num_passes; 00130 std::map<int32_t, node_t*> m_leaves; 00131 CStreamingVwFeatures *m_feats; 00132 }; 00133 00134 } /* shogun */ 00135 00136 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */ 00137