SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelaxedTree.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #ifndef RELAXEDTREE_H__
12 #define RELAXEDTREE_H__
13 
14 #include <utility>
15 #include <vector>
16 
21 
22 namespace shogun
23 {
24 
25 class CBaseMulticlassMachine;
26 
34 class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
35 {
36 public:
38  CRelaxedTree();
39 
41  virtual ~CRelaxedTree();
42 
44  virtual const char* get_name() const { return "RelaxedTree"; }
45 
47  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
48 
53  {
54  SG_REF(feats);
56  m_feats = feats;
57  }
58 
62  virtual void set_kernel(CKernel *kernel)
63  {
64  SG_REF(kernel);
66  m_kernel = kernel;
67  }
68 
73  virtual void set_labels(CLabels* lab)
74  {
75  CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
76  REQUIRE(lab, "requires MulticlassLabes\n");
77 
80  }
81 
86  {
87  SG_REF(machine);
90  }
91 
96  {
97  m_svm_C = C;
98  }
103  {
104  return m_svm_C;
105  }
106 
111  {
113  }
118  {
119  return m_svm_epsilon;
120  }
121 
127  void set_A(float64_t A)
128  {
129  m_A = A;
130  }
134  float64_t get_A() const
135  {
136  return m_A;
137  }
138 
143  void set_B(int32_t B)
144  {
145  m_B = B;
146  }
150  int32_t get_B() const
151  {
152  return m_B;
153  }
154 
158  void set_max_num_iter(int32_t n_iter)
159  {
160  m_max_num_iter = n_iter;
161  }
165  int32_t get_max_num_iter() const
166  {
167  return m_max_num_iter;
168  }
169 
179  virtual bool train(CFeatures* data=NULL)
180  {
181  return CMachine::train(data);
182  }
183 
185  typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;
186 protected:
193  float64_t apply_one(int32_t idx);
194 
201  virtual bool train_machine(CFeatures* data);
202 
204  node_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);
206  std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);
209 
216 
218  void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
220  void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
221 
223  int32_t m_max_num_iter;
227  int32_t m_B;
239  int32_t m_num_classes;
240 };
241 
242 } /* shogun */
243 
244 #endif /* end of include guard: RELAXEDTREE_H__ */
245 

SHOGUN Machine Learning Toolbox - Documentation