SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
CARTree.h
浏览该文件的文档.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 
32 #ifndef _CARTREE_H__
33 #define _CARTREE_H__
34 
35 #include <shogun/lib/config.h>
36 
40 
41 namespace shogun
42 {
43 
// CCARTree: tree machine specialised on CARTreeNodeData (derives from
// CTreeMachine<CARTreeNodeData>). The cross-reference index at the end of this
// page describes it as implementing the Classification And Regression Trees
// algorithm by Breiman et al.
//
// NOTE(review): this is a rendered source listing, not the raw header. The
// number at the start of every line is the line number in CARTree.h (it
// matches the "Definition: CARTree.h:<n>" entries in the index), and lines
// showing only a number are places where the renderer dropped a doc comment
// or a declaration. The comments below point out those gaps.
79 class CCARTree : public CTreeMachine<CARTreeNodeData>
80 {
81 public:
// Default constructor.
83  CCARTree();
84 
// Constructor taking one bool per attribute plus the problem type, which
// defaults to PT_MULTICLASS.
// NOTE(review): presumably attribute_types marks nominal attributes (the class
// has an SGVector<bool> m_nominal member per the index) - confirm in CARTree.cpp.
89  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type=PT_MULTICLASS);
90 
// Constructor additionally taking a fold count and a cross-validation pruning
// flag.
97  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune);
98 
// Virtual destructor.
100  virtual ~CCARTree();
101 
// Sets the labels object used by this machine (defined in CARTree.cpp).
105  virtual void set_labels(CLabels* lab);
106 
// Returns the fixed class name string "CARTree".
110  virtual const char* get_name() const { return "CARTree"; }
111 
// Returns the stored problem type m_mode.
115  virtual EProblemType get_machine_problem_type() const { return m_mode; }
116 
// NOTE(review): a declaration is missing from the listing here; the index
// lists set_machine_problem_type(EProblemType mode), which presumably belongs
// at this position - confirm against the real header.
121 
// Label validity check for the given labels object (defined in CARTree.cpp).
126  virtual bool is_label_valid(CLabels* lab) const;
127 
// Applies the machine and returns multiclass labels; data defaults to NULL.
132  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
133 
// Applies the machine and returns regression labels; data defaults to NULL.
138  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
139 
// NOTE(review): three declarations are missing from the listing here. The
// index lists prune_using_test_dataset(...), set_weights(...) and
// get_weights() as members; presumably they sit at these positions - confirm.
147 
152 
157 
// Counterpart of set_weights()/get_weights() listed in the index.
159  void clear_weights();
160 
// NOTE(review): two declarations are missing here; presumably
// set_feature_types(...) and get_feature_types() from the index - confirm.
165 
170 
// Counterpart of set_feature_types()/get_feature_types() listed in the index.
172  void clear_feature_types();
173 
// Accessors for the fold count, maximum depth and minimum node size.
178  int32_t get_num_folds() const;
179 
184  void set_num_folds(int32_t folds);
185 
190  int32_t get_max_depth() const;
191 
196  void set_max_depth(int32_t depth);
197 
202  int32_t get_min_node_size() const;
203 
208  void set_min_node_size(int32_t nsize);
209 
// The index places set_cv_pruning() at header line 211, unset_cv_pruning() at
// 214 and get_label_epsilon() at 220; those lines are absent from this listing
// and only the lines following them survive.
// NOTE(review): set_label_epsilon(float64_t) from the index presumably belongs
// just before line 227 - confirm.
212 
215 
221 
227 
228 protected:
229 
// Training entry point; data defaults to NULL.
234  virtual bool train_machine(CFeatures* data=NULL);
235 
// Returns a binary tree node built from the given data, weights and labels;
// level is passed as an int32_t.
// NOTE(review): presumably the recursive tree-growing routine with level as
// the current depth - confirm in CARTree.cpp.
244  virtual CBinaryTreeMachineNode<CARTreeNodeData>* CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
245 
// Returns a vector derived from labels_vec; n_ulabels is a non-const reference
// and can therefore be written by the callee.
252  SGVector<float64_t> get_unique_labels(SGVector<float64_t> labels_vec, int32_t &n_ulabels);
253 
// The next two lines are only the tail of a parameter list; the start of the
// declaration was lost. The tail matches the signature the index gives for
// compute_best_attribute(mat, weights, labels_vec, left, right, is_left_final,
// num_missing, count_left, count_right), a virtual member returning int32_t.
268  SGVector<float64_t> left, SGVector<float64_t> right, SGVector<bool> is_left_final, int32_t &num_missing,
269  int32_t &count_left, int32_t &count_right);
270 
271 
// NOTE(review): a declaration is missing here; presumably surrogate_split(...)
// from the index - confirm.
281 
282 
// Parameter-list tail of another declaration whose start was lost. It matches
// both handle_missing_vecs_for_continuous_surrogate(...) and
// handle_missing_vecs_for_nominal_surrogate(...) in the index.
// NOTE(review): presumably this one is the continuous variant and the next
// fragment the nominal one (that is their order in CARTree.cpp) - confirm.
296  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
297  SGVector<float64_t> weights, float64_t p, int32_t attr);
298 
// Second fragment with the identical tail; see the note on the previous one.
312  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
313  SGVector<float64_t> weights, float64_t p, int32_t attr);
314 
// NOTE(review): declarations are missing here; the index lists an overload of
// gain(wleft, wright, wtotal, labels) - confirm what else belongs here.
324 
333 
// Returns a float64_t computed from weighted_lab_classes; total_weight is a
// non-const reference (an output or in-out parameter).
340  float64_t gini_impurity_index(SGVector<float64_t> weighted_lab_classes, float64_t &total_weight);
341 
// NOTE(review): declarations are missing here; presumably
// least_squares_deviation(...) and apply_from_current_node(...) from the
// index - confirm.
350 
358 
// Pruning driven by cross-validation over the given dense features and number
// of folds (defined in CARTree.cpp).
364  void prune_by_cross_validation(CDenseFeatures<float64_t>* data, int32_t folds);
365 
// Returns an error value for labels measured against reference, taking a
// weights vector.
375  float64_t compute_error(CLabels* labels, CLabels* reference, SGVector<float64_t> weights);
376 
// NOTE(review): declarations are missing here; presumably prune_tree(...) and
// find_weakest_alpha(...) from the index - confirm.
383 
390 
// Operates on the subtree rooted at node with the given alpha.
396  void cut_weakest_link(bnode_t* node, float64_t alpha);
397 
// Operates on the subtree rooted at node (defined in CARTree.cpp).
402  void form_t1(bnode_t* node);
403 
// NOTE(review): presumably sets member defaults and is called from the
// constructors - confirm in CARTree.cpp.
405  void init();
406 
407 public:
// Class-level constants; their values are defined outside this header.
// NOTE(review): names suggest a missing-value marker, a minimum gain needed to
// split, and an equality tolerance - confirm in CARTree.cpp.
409  static const float64_t MISSING;
410 
412  static const float64_t MIN_SPLIT_GAIN;
413 
415  static const float64_t EQ_DELTA;
416 
417 protected:
// The index places the following members on header lines absent from this
// listing: m_label_epsilon (419), m_nominal (422), m_weights (425),
// m_types_set (428), m_weights_set (431) and m_apply_cv_pruning (434).
420 
423 
426 
429 
432 
435 
// Fold count; see get_num_folds()/set_num_folds().
437  int32_t m_folds;
438 
// The index places m_mode (440) and m_alphas (443) on lines absent here.
441 
444 
// Maximum depth; see get_max_depth()/set_max_depth().
446  int32_t m_max_depth;
447 
// The index places m_min_node_size on line 449, absent from this listing.
450 };
451 } /* namespace shogun */
452 
453 #endif /* _CARTREE_H__ */
void set_cv_pruning()
Definition: CARTree.h:211
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:976
bool m_types_set
Definition: CARTree.h:428
The node of the tree structure forming a TreeMachine. The node contains a pointer to its parent and poin...
int32_t get_max_depth() const
Definition: CARTree.cpp:214
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:169
Real Labels are real-valued labels.
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:192
int32_t get_min_node_size() const
Definition: CARTree.cpp:225
float64_t get_label_epsilon()
Definition: CARTree.h:220
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:84
int32_t get_num_folds() const
Definition: CARTree.cpp:203
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1236
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:99
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1284
float64_t least_squares_deviation(SGVector< float64_t > labels, SGVector< float64_t > weights, float64_t &total_weight)
Definition: CARTree.cpp:958
void unset_cv_pruning()
Definition: CARTree.h:214
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1335
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:89
virtual const char * get_name() const
Definition: CARTree.h:110
static const float64_t EQ_DELTA
Definition: CARTree.h:415
bool m_apply_cv_pruning
Definition: CARTree.h:434
virtual ~CCARTree()
Definition: CARTree.cpp:65
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:706
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:242
int32_t m_max_depth
Definition: CARTree.h:446
float64_t m_label_epsilon
Definition: CARTree.h:419
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:70
Multiclass Labels for multi-class classification.
static const float64_t epsilon
Definition: libbmrm.cpp:25
float64_t gini_impurity_index(SGVector< float64_t > weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:944
void clear_feature_types()
Definition: CARTree.cpp:197
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:918
EProblemType
Definition: Machine.h:110
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:230
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:781
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:208
double float64_t
Definition: common.h:50
int32_t m_min_node_size
Definition: CARTree.h:449
int32_t m_folds
Definition: CARTree.h:437
bool m_weights_set
Definition: CARTree.h:431
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:175
void clear_weights()
Definition: CARTree.cpp:180
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:285
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:219
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:838
This class implements the Classification And Regression Trees algorithm by Breiman et al for decision...
Definition: CARTree.h:79
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:186
virtual int32_t compute_best_attribute(SGMatrix< float64_t > mat, SGVector< float64_t > weights, SGVector< float64_t > labels_vec, SGVector< float64_t > left, SGVector< float64_t > right, SGVector< bool > is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right)
Definition: CARTree.cpp:486
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual EProblemType get_machine_problem_type() const
Definition: CARTree.h:115
SGVector< bool > m_nominal
Definition: CARTree.h:422
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1198
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1305
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:412
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:111
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:123
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
SGVector< float64_t > m_weights
Definition: CARTree.h:425
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:443
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:462
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:236
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1059
static const float64_t MISSING
Definition: CARTree.h:409
EProblemType m_mode
Definition: CARTree.h:440

SHOGUN 机器学习工具包 - 项目文档