SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
RelaxedTree.h
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #ifndef RELAXEDTREE_H__
12 #define RELAXEDTREE_H__
13 
14 #include <utility>
15 #include <vector>
16 
17 #include <shogun/lib/config.h>
18 
23 
24 namespace shogun
25 {
26 
27 class CBaseMulticlassMachine;
28 
36 class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
37 {
38 public:
40  CRelaxedTree();
41 
43  virtual ~CRelaxedTree();
44 
46  virtual const char* get_name() const { return "RelaxedTree"; }
47 
49  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
50 
55  {
56  SG_REF(feats);
58  m_feats = feats;
59  }
60 
64  virtual void set_kernel(CKernel *kernel)
65  {
66  SG_REF(kernel);
68  m_kernel = kernel;
69  }
70 
75  virtual void set_labels(CLabels* lab)
76  {
77  CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
78  REQUIRE(lab, "requires MulticlassLabes\n")
79 
82  }
83 
88  {
89  SG_REF(machine);
92  }
93 
98  {
99  m_svm_C = C;
100  }
105  {
106  return m_svm_C;
107  }
108 
113  {
115  }
120  {
121  return m_svm_epsilon;
122  }
123 
129  void set_A(float64_t A)
130  {
131  m_A = A;
132  }
136  float64_t get_A() const
137  {
138  return m_A;
139  }
140 
145  void set_B(int32_t B)
146  {
147  m_B = B;
148  }
152  int32_t get_B() const
153  {
154  return m_B;
155  }
156 
160  void set_max_num_iter(int32_t n_iter)
161  {
162  m_max_num_iter = n_iter;
163  }
167  int32_t get_max_num_iter() const
168  {
169  return m_max_num_iter;
170  }
171 
181  virtual bool train(CFeatures* data=NULL)
182  {
183  return CMachine::train(data);
184  }
185 
187  typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;
188 protected:
195  float64_t apply_one(int32_t idx);
196 
203  virtual bool train_machine(CFeatures* data);
204 
206  bnode_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);
208  std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);
211 
218 
220  void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
222  void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);
223 
225  int32_t m_max_num_iter;
229  int32_t m_B;
241  int32_t m_num_classes;
242 };
243 
244 } /* shogun */
245 
246 #endif /* end of include guard: RELAXEDTREE_H__ */
247 
std::pair< std::pair< int32_t, int32_t >, float64_t > entry_t
Definition: RelaxedTree.h:187
The node of the tree structure forming a TreeMachine. The node contains a pointer to its parent and poin...
SGVector< float64_t > eval_binary_model_K(CSVM *svm)
void set_svm_epsilon(float64_t epsilon)
Definition: RelaxedTree.h:112
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual bool train(CFeatures *data=NULL)
Definition: RelaxedTree.h:181
float64_t compute_score(SGVector< int32_t > mu, CSVM *svm)
#define REQUIRE(x,...)
Definition: SGIO.h:206
void set_svm_C(float64_t C)
Definition: RelaxedTree.h:97
void set_machine_for_confusion_matrix(CBaseMulticlassMachine *machine)
Definition: RelaxedTree.h:87
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: RelaxedTree.cpp:42
#define SG_REF(x)
Definition: SGObject.h:51
float64_t apply_one(int32_t idx)
Definition: RelaxedTree.cpp:73
Multiclass Labels for multi-class classification.
static const float64_t epsilon
Definition: libbmrm.cpp:25
SGVector< int32_t > train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector< int32_t > classes, CSVM *svm)
void enforce_balance_constraints_lower(SGVector< int32_t > &mu, SGVector< float64_t > &delta_neg, SGVector< float64_t > &delta_pos, int32_t B_prime, SGVector< float64_t > &xi_neg_class)
void set_features(CDenseFeatures< float64_t > *feats)
Definition: RelaxedTree.h:54
void set_max_num_iter(int32_t n_iter)
Definition: RelaxedTree.h:160
double float64_t
Definition: common.h:50
void set_B(int32_t B)
Definition: RelaxedTree.h:145
CDenseFeatures< float64_t > * m_feats
Definition: RelaxedTree.h:237
virtual const char * get_name() const
Definition: RelaxedTree.h:46
void set_A(float64_t A)
Definition: RelaxedTree.h:129
float64_t get_svm_epsilon() const
Definition: RelaxedTree.h:119
SGVector< int32_t > color_label_space(CSVM *svm, SGVector< int32_t > classes)
virtual bool train_machine(CFeatures *data)
float64_t get_svm_C() const
Definition: RelaxedTree.h:104
#define SG_UNREF(x)
Definition: SGObject.h:52
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
void enforce_balance_constraints_upper(SGVector< int32_t > &mu, SGVector< float64_t > &delta_neg, SGVector< float64_t > &delta_pos, int32_t B_prime, SGVector< float64_t > &xi_neg_class)
bnode_t * train_node(const SGMatrix< float64_t > &conf_mat, SGVector< int32_t > classes)
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual void set_kernel(CKernel *kernel)
Definition: RelaxedTree.h:64
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:39
A generic Support Vector Machine Interface.
Definition: SVM.h:49
float64_t get_A() const
Definition: RelaxedTree.h:136
The Kernel base class.
Definition: Kernel.h:158
int32_t get_max_num_iter() const
Definition: RelaxedTree.h:167
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
virtual void set_labels(CLabels *lab)
Definition: RelaxedTree.h:75
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:65
std::vector< entry_t > init_node(const SGMatrix< float64_t > &global_conf_mat, SGVector< int32_t > classes)
int32_t get_B() const
Definition: RelaxedTree.h:152
float64_t m_svm_epsilon
Definition: RelaxedTree.h:233
CBaseMulticlassMachine * m_machine_for_confusion_matrix
Definition: RelaxedTree.h:239

SHOGUN 机器学习工具包 - 项目文档