VwConditionalProbabilityTree.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #ifndef CONDITIONALPROBABILITYTREE_H__
00012 #define CONDITIONALPROBABILITYTREE_H__
00013 
00014 #include <map>
00015 
00016 #include <shogun/multiclass/tree/TreeMachine.h>
00017 #include <shogun/classifier/vw/VowpalWabbit.h>
00018 
00019 namespace shogun
00020 {
00021 
00023 struct VwConditionalProbabilityTreeNodeData
00024 {
00026     int32_t label;
00028     float64_t p_right; 
00029 
00031     VwConditionalProbabilityTreeNodeData():label(-1), p_right(0) {}
00032 };
00033 
00035 class CVwConditionalProbabilityTree: public CTreeMachine<VwConditionalProbabilityTreeNodeData>
00036 {
00037 public:
00038     typedef CTreeMachineNode<VwConditionalProbabilityTreeNodeData> node_t;
00039 
00041     CVwConditionalProbabilityTree(int32_t num_passes=1)
00042         :m_num_passes(num_passes), m_feats(NULL)
00043     {
00044     }
00045 
00047     virtual ~CVwConditionalProbabilityTree() {}
00048 
00050     virtual const char* get_name() const { return "VwConditionalProbabilityTree"; }
00051 
00053     void set_num_passes(int32_t num_passes)
00054     {
00055         m_num_passes = num_passes;
00056     }
00057 
00059     int32_t get_num_passes() const
00060     {
00061         return m_num_passes;
00062     }
00063 
00067     void set_features(CStreamingVwFeatures *feats)
00068     {
00069         SG_REF(feats);
00070         SG_UNREF(m_feats);
00071         m_feats = feats;
00072     }
00073 
00075     virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00076 
00078     virtual int32_t apply_multiclass_example(VwExample* ex);
00079 protected:
00081     virtual bool train_require_labels() const { return false; }
00082 
00089     virtual bool train_machine(CFeatures* data);
00090 
00094     void train_example(VwExample *ex);
00095 
00100     void train_path(VwExample *ex, node_t *node);
00101 
00107     float64_t train_node(VwExample *ex, node_t *node);
00108 
00112     int32_t create_machine(VwExample *ex);
00113 
00119     virtual bool which_subtree(node_t *node, VwExample *ex)=0;
00120 
00122     void compute_conditional_probabilities(VwExample *ex);
00123 
00127     float64_t accumulate_conditional_probability(node_t *leaf);
00128 
00129     int32_t m_num_passes; 
00130     std::map<int32_t, node_t*> m_leaves; 
00131     CStreamingVwFeatures *m_feats; 
00132 };
00133 
00134 } /* shogun */ 
00135 
00136 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
00137 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation