SHOGUN  6.1.3
CARTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 
32 #ifndef _CARTREE_H__
33 #define _CARTREE_H__
34 
35 #include <shogun/lib/config.h>
36 
40 
41 namespace shogun
42 {
43 
/**
 * CART (Classification And Regression Trees, Breiman et al.) tree machine,
 * derived from CTreeMachine<CARTreeNodeData>.
 *
 * get_machine_problem_type() reports the problem mode stored in m_mode. The
 * two-argument constructor's prob_type parameter defaults to PT_MULTICLASS.
 * Predictions come from apply_multiclass() or apply_regression().
 *
 * NOTE(review): this is a Doxygen source listing, not compilable source. The
 * leading numbers are the header's own line numbers. All documentation
 * comments and several declarations were stripped, including set_weights,
 * get_weights, set_feature_types, get_feature_types, get_label_epsilon,
 * surrogate_split and most data members. The symbol index after the listing
 * gives their signatures and definition locations. A blank numbered line
 * usually directly follows a stripped declaration.
 */
79 class CCARTree : public CTreeMachine<CARTreeNodeData>
80 {
81 public:
    /** Default constructor. */
83  CCARTree();
84 
    /** Constructor.
     * @param attribute_types one flag per attribute; presumably true marks a
     *        nominal attribute (cf. m_nominal / set_feature_types) - TODO confirm
     * @param prob_type problem type; PT_MULTICLASS when omitted
     */
89  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type=PT_MULTICLASS);
90 
    /** Constructor that additionally configures cross-validation pruning.
     * @param attribute_types one flag per attribute (presumably true = nominal - TODO confirm)
     * @param prob_type problem type (no default in this overload)
     * @param num_folds number of folds, cf. get_num_folds()/set_num_folds()
     * @param cv_prune whether cross-validation pruning is applied, cf. set_cv_pruning()
     */
97  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune);
98 
    /** Destructor (defined at CARTree.cpp:67). */
100  virtual ~CCARTree();
101 
    /** Set the training labels (defined at CARTree.cpp:72).
     * @param lab labels; presumably checked with is_label_valid() - TODO confirm
     */
105  virtual void set_labels(CLabels* lab);
106 
    /** @return the object name, "CARTree" */
110  virtual const char* get_name() const { return "CARTree"; }
111 
    /** @return the problem type stored in m_mode */
115  virtual EProblemType get_machine_problem_type() const { return m_mode; }
116 
    /* NOTE(review): stripped declaration above; by index order presumably
     * void set_machine_problem_type(EProblemType mode) (CARTree.cpp:86). */
121 
    /** Check whether labels are acceptable for this machine (CARTree.cpp:91).
     * @param lab labels to check
     * @return true if valid
     */
126  virtual bool is_label_valid(CLabels* lab) const;
127 
    /** Classify data (CARTree.cpp:101).
     * @param data features to classify; NULL by default
     * @return multiclass labels
     */
132  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
133 
    /** Regress data (CARTree.cpp:115).
     * @param data features to regress on; NULL by default
     * @return regression labels
     */
138  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
139 
    /* NOTE(review): stripped declarations here; by index order presumably
     * prune_using_test_dataset (CARTree.cpp:127), set_weights (:173) and
     * get_weights (:179). */
147 
152 
157 
    /** Clear the stored weights (CARTree.cpp:184). */
159  void clear_weights();
160 
    /* NOTE(review): stripped declarations here; presumably
     * set_feature_types (CARTree.cpp:190) and get_feature_types (:196). */
165 
170 
    /** Clear the stored feature types (CARTree.cpp:201). */
172  void clear_feature_types();
173 
    /** @return number of folds (CARTree.cpp:207) */
178  int32_t get_num_folds() const;
179 
    /** @param folds number of folds (CARTree.cpp:212) */
184  void set_num_folds(int32_t folds);
185 
    /** @return maximum tree depth (CARTree.cpp:218) */
190  int32_t get_max_depth() const;
191 
    /** @param depth maximum tree depth (CARTree.cpp:223) */
196  void set_max_depth(int32_t depth);
197 
    /** @return minimum node size (CARTree.cpp:229) */
202  int32_t get_min_node_size() const;
203 
    /** @param nsize minimum node size (CARTree.cpp:234) */
208  void set_min_node_size(int32_t nsize);
209 
    /** Enable or disable cross-validation pruning by setting m_apply_cv_pruning.
     * @param cv_pruning new value of m_apply_cv_pruning
     */
214  void set_cv_pruning(bool cv_pruning)
215  {
216  m_apply_cv_pruning = cv_pruning;
217  }
218 
    /* NOTE(review): stripped declaration above. The index lists
     * float64_t get_label_epsilon() at CARTree.h:223 with no const
     * qualifier, unlike the other getters; consider making it const. */
224 
    /** @param epsilon new label epsilon, stored in m_label_epsilon presumably (CARTree.cpp:240) */
229  void set_label_epsilon(float64_t epsilon);
230 
    /** Pre-sort features (CARTree.cpp:296).
     * @param data input features
     * @param sorted_feats output matrix of sorted feature values (filled by reference - TODO confirm)
     * @param sorted_indices output matrix of the corresponding indices
     */
231  void pre_sort_features(CFeatures* data, SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
232 
    /** Provide pre-sorted features (CARTree.cpp:289); presumably stored in
     * m_sorted_features / m_sorted_indices - TODO confirm.
     */
233  void set_sorted_features(SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
234 
235 protected:
    /** Train the tree (CARTree.cpp:246).
     * @param data training features; NULL by default
     * @return whether training succeeded
     */
240  virtual bool train_machine(CFeatures* data=NULL);
241 
    /** Build a (sub)tree node (CARTree.cpp:316).
     * @param data features
     * @param weights per-vector weights
     * @param labels labels
     * @param level current depth, presumably compared with m_max_depth - TODO confirm
     * @return the created binary tree node
     */
250  virtual CBinaryTreeMachineNode<CARTreeNodeData>* CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
251 
    /** Collect unique labels (CARTree.cpp:506).
     * @param labels_vec label values
     * @param n_ulabels output: number of unique labels
     * @return vector of unique labels
     */
258  SGVector<float64_t> get_unique_labels(SGVector<float64_t> labels_vec, int32_t &n_ulabels);
259 
    /** Choose the split attribute (CARTree.cpp:530).
     * Returns an int32_t, presumably the best attribute index. left, right,
     * is_left_final, num_missing, count_left and count_right are outputs
     * passed by non-const reference. subset_size defaults to 0 and
     * active_indices to an empty vector - TODO confirm their semantics.
     */
273  virtual int32_t compute_best_attribute(const SGMatrix<float64_t>& mat, const SGVector<float64_t>& weights, CLabels* labels,
274  SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing,
275  int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector<int32_t>& active_indices=SGVector<index_t>());
276 
277 
    /* NOTE(review): stripped declaration above; presumably
     * SGVector<bool> surrogate_split(SGMatrix<float64_t> data, SGVector<float64_t> weights,
     *     SGVector<bool> nm_left, int32_t attr) (CARTree.cpp:840). */
287 
288 
    /* NOTE(review): the next two lines are only the tail of a declaration
     * whose first line was stripped. By index order this is presumably
     * void handle_missing_vecs_for_continuous_surrogate(SGMatrix<float64_t> m,
     *     CDynamicArray<int32_t>* missing_vecs, ... (CARTree.cpp:915). */
302  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
303  SGVector<float64_t> weights, float64_t p, int32_t attr);
304 
    /* NOTE(review): truncated declaration as well; presumably
     * void handle_missing_vecs_for_nominal_surrogate(SGMatrix<float64_t> m,
     *     CDynamicArray<int32_t>* missing_vecs, ... (CARTree.cpp:972). */
318  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
319  SGVector<float64_t> weights, float64_t p, int32_t attr);
320 
330 
    /** Split gain from weight vectors.
     * NOTE(review): the symbol index lists gain with a different signature,
     * gain(SGVector wleft, SGVector wright, SGVector wtotal, SGVector labels)
     * at CARTree.cpp:1052. Confirm which declaration matches the implementation.
     */
338  float64_t gain(const SGVector<float64_t>& wleft, const SGVector<float64_t>& wright, const SGVector<float64_t>& wtotal);
339 
    /** Gini impurity index (CARTree.cpp:1078).
     * @param weighted_lab_classes per-class weights
     * @param total_weight output, presumably the summed weight - TODO confirm
     */
346  float64_t gini_impurity_index(const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight);
347 
    /** Least squares deviation (CARTree.cpp:1088).
     * @param labels label values
     * @param weights per-vector weights
     * @param total_weight output, presumably the summed weight - TODO confirm
     */
355  float64_t least_squares_deviation(const SGVector<float64_t>& labels, const SGVector<float64_t>& weights, float64_t &total_weight);
356 
    /* NOTE(review): stripped declaration above; presumably
     * CLabels* apply_from_current_node(CDenseFeatures<float64_t>* feats, bnode_t* current)
     * (CARTree.cpp:1104). */
364 
    /** Prune the tree by cross-validation (CARTree.cpp:1189).
     * @param data training features
     * @param folds number of folds
     */
370  void prune_by_cross_validation(CDenseFeatures<float64_t>* data, int32_t folds);
371 
    /** Weighted error of labels against a reference (CARTree.cpp:1328).
     * @param labels predicted labels
     * @param reference reference labels
     * @param weights per-vector weights
     */
381  float64_t compute_error(CLabels* labels, CLabels* reference, SGVector<float64_t> weights);
382 
    /* NOTE(review): stripped declarations here; presumably
     * prune_tree (CARTree.cpp:1366) and find_weakest_alpha (CARTree.cpp:1416). */
389 
396 
    /** Cut weakest link(s) for the given alpha (CARTree.cpp:1437).
     * @param node node to start from
     * @param alpha pruning parameter
     */
402  void cut_weakest_link(bnode_t* node, float64_t alpha);
403 
    /** Form tree T1 for pruning (CARTree.cpp:1467).
     * @param node node to start from
     */
408  void form_t1(bnode_t* node);
409 
    /** Common initialisation, presumably called by the constructors - TODO confirm. */
411  void init();
412 
413 
414 public:
    /** Class constant; presumably the sentinel value for a missing feature value - TODO confirm. */
416  static const float64_t MISSING;
417 
    /** Class constant; presumably the minimum gain required to split - TODO confirm. */
419  static const float64_t MIN_SPLIT_GAIN;
420 
    /** Class constant; presumably the tolerance for floating-point equality - TODO confirm. */
422  static const float64_t EQ_DELTA;
423 
424 protected:
    /* NOTE(review): the data members below were stripped from the listing.
     * Per the index: m_label_epsilon (h:426), m_nominal (h:429),
     * m_weights (h:432), m_sorted_features (h:435), m_sorted_indices (h:438),
     * m_types_set (h:444), m_weights_set (h:447), m_apply_cv_pruning (h:450).
     * Line 441 is not listed in the index. */
427 
430 
433 
436 
439 
442 
445 
448 
451 
    /** Number of folds (see get_num_folds()/set_num_folds()). */
453  int32_t m_folds;
454 
    /* NOTE(review): stripped members: m_mode (h:456), m_alphas (h:459). */
457 
460 
    /** Maximum tree depth (see get_max_depth()/set_max_depth()). */
462  int32_t m_max_depth;
463 
    /* NOTE(review): stripped member: m_min_node_size (h:465). */
466 };
467 } /* namespace shogun */
468 
469 #endif /* _CARTREE_H__ */
void set_cv_pruning(bool cv_pruning)
Definition: CARTree.h:214
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:1104
bool m_types_set
Definition: CARTree.h:444
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
Definition: CARTree.cpp:530
The node of the tree structure forming a TreeMachine. The node contains a pointer to its parent and poin...
int32_t get_max_depth() const
Definition: CARTree.cpp:218
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:173
Real Labels are real-valued labels.
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:196
int32_t get_min_node_size() const
Definition: CARTree.cpp:229
float64_t get_label_epsilon()
Definition: CARTree.h:223
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:86
int32_t get_num_folds() const
Definition: CARTree.cpp:207
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1366
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
SGMatrix< index_t > m_sorted_indices
Definition: CARTree.h:438
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:101
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1416
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1467
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:91
virtual const char * get_name() const
Definition: CARTree.h:110
static const float64_t EQ_DELTA
Definition: CARTree.h:422
bool m_apply_cv_pruning
Definition: CARTree.h:450
virtual ~CCARTree()
Definition: CARTree.cpp:67
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:840
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:246
int32_t m_max_depth
Definition: CARTree.h:462
float64_t m_label_epsilon
Definition: CARTree.h:426
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:72
Multiclass Labels for multi-class classification.
void clear_feature_types()
Definition: CARTree.cpp:201
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:1052
EProblemType
Definition: Machine.h:113
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
Definition: CARTree.cpp:1088
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:234
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:289
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:915
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:212
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:1078
double float64_t
Definition: common.h:60
int32_t m_min_node_size
Definition: CARTree.h:465
int32_t m_folds
Definition: CARTree.h:453
bool m_weights_set
Definition: CARTree.h:447
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:179
void clear_weights()
Definition: CARTree.cpp:184
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
SGMatrix< float64_t > m_sorted_features
Definition: CARTree.h:435
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:316
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:223
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:972
This class implements the Classification And Regression Trees algorithm by Breiman et al. for decision...
Definition: CARTree.h:79
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:296
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:190
The class Features is the base class of all feature objects.
Definition: Features.h:69
virtual EProblemType get_machine_problem_type() const
Definition: CARTree.h:115
SGVector< bool > m_nominal
Definition: CARTree.h:429
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1328
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1437
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:419
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:115
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:127
class TreeMachine, a base class for tree-based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
SGVector< float64_t > m_weights
Definition: CARTree.h:432
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:459
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:506
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:240
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1189
static const float64_t MISSING
Definition: CARTree.h:416
EProblemType m_mode
Definition: CARTree.h:456

SHOGUN Machine Learning Toolbox - Documentation