Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef RELAXEDTREE_H__
00012 #define RELAXEDTREE_H__
00013
00014 #include <utility>
00015 #include <vector>
00016
00017 #include <shogun/features/DenseFeatures.h>
00018 #include <shogun/classifier/svm/LibSVM.h>
00019 #include <shogun/multiclass/tree/TreeMachine.h>
00020 #include <shogun/multiclass/tree/RelaxedTreeNodeData.h>
00021
00022 namespace shogun
00023 {
00024
00025 class CBaseMulticlassMachine;
00026
00034 class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
00035 {
00036 public:
00038 CRelaxedTree();
00039
00041 virtual ~CRelaxedTree();
00042
00044 virtual const char* get_name() const { return "RelaxedTree"; }
00045
00047 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00048
00052 void set_features(CDenseFeatures<float64_t> *feats)
00053 {
00054 SG_REF(feats);
00055 SG_UNREF(m_feats);
00056 m_feats = feats;
00057 }
00058
00062 virtual void set_kernel(CKernel *kernel)
00063 {
00064 SG_REF(kernel);
00065 SG_UNREF(m_kernel);
00066 m_kernel = kernel;
00067 }
00068
00073 virtual void set_labels(CLabels* lab)
00074 {
00075 CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
00076 REQUIRE(lab, "requires MulticlassLabes\n");
00077
00078 CMachine::set_labels(mlab);
00079 m_num_classes = mlab->get_num_classes();
00080 }
00081
00085 void set_machine_for_confusion_matrix(CBaseMulticlassMachine *machine)
00086 {
00087 SG_REF(machine);
00088 SG_UNREF(m_machine_for_confusion_matrix);
00089 m_machine_for_confusion_matrix = machine;
00090 }
00091
00095 void set_svm_C(float64_t C)
00096 {
00097 m_svm_C = C;
00098 }
00102 float64_t get_svm_C() const
00103 {
00104 return m_svm_C;
00105 }
00106
00110 void set_svm_epsilon(float64_t epsilon)
00111 {
00112 m_svm_epsilon = epsilon;
00113 }
00117 float64_t get_svm_epsilon() const
00118 {
00119 return m_svm_epsilon;
00120 }
00121
00127 void set_A(float64_t A)
00128 {
00129 m_A = A;
00130 }
00134 float64_t get_A() const
00135 {
00136 return m_A;
00137 }
00138
00143 void set_B(int32_t B)
00144 {
00145 m_B = B;
00146 }
00150 int32_t get_B() const
00151 {
00152 return m_B;
00153 }
00154
00158 void set_max_num_iter(int32_t n_iter)
00159 {
00160 m_max_num_iter = n_iter;
00161 }
00165 int32_t get_max_num_iter() const
00166 {
00167 return m_max_num_iter;
00168 }
00169
00179 virtual bool train(CFeatures* data=NULL)
00180 {
00181 return CMachine::train(data);
00182 }
00183
00185 typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;
00186 protected:
00193 float64_t apply_one(int32_t idx);
00194
00201 virtual bool train_machine(CFeatures* data);
00202
00204 node_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);
00206 std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);
00208 SGVector<int32_t> train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector<int32_t> classes, CSVM *svm);
00209
00211 float64_t compute_score(SGVector<int32_t> mu, CSVM *svm);
00213 SGVector<int32_t> color_label_space(CSVM *svm, SGVector<int32_t> classes);
00215 SGVector<float64_t> eval_binary_model_K(CSVM *svm);
00216
00218 void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
00220 void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
00221
00223 int32_t m_max_num_iter;
00225 float64_t m_A;
00227 int32_t m_B;
00229 float64_t m_svm_C;
00231 float64_t m_svm_epsilon;
00233 CKernel *m_kernel;
00235 CDenseFeatures<float64_t> *m_feats;
00237 CBaseMulticlassMachine *m_machine_for_confusion_matrix;
00239 int32_t m_num_classes;
00240 };
00241
00242 }
00243
00244 #endif
00245