SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CARTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 
32 #ifndef _CARTREE_H__
33 #define _CARTREE_H__
34 
#include <shogun/lib/config.h>

#include <shogun/multiclass/tree/TreeMachine.h>
#include <shogun/multiclass/tree/CARTreeNodeData.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/lib/DynamicArray.h>
#include <shogun/lib/DynamicObjectArray.h>

41 namespace shogun
42 {
43 
79 class CCARTree : public CTreeMachine<CARTreeNodeData>
80 {
81 public:
83  CCARTree();
84 
89  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type=PT_MULTICLASS);
90 
97  CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune);
98 
100  virtual ~CCARTree();
101 
105  virtual void set_labels(CLabels* lab);
106 
110  virtual const char* get_name() const { return "CARTree"; }
111 
115  virtual EProblemType get_machine_problem_type() const { return m_mode; }
116 
121 
126  virtual bool is_label_valid(CLabels* lab) const;
127 
132  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
133 
138  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
139 
147 
152 
157 
159  void clear_weights();
160 
165 
170 
172  void clear_feature_types();
173 
178  int32_t get_num_folds() const;
179 
184  void set_num_folds(int32_t folds);
185 
190  int32_t get_max_depth() const;
191 
196  void set_max_depth(int32_t depth);
197 
202  int32_t get_min_node_size() const;
203 
208  void set_min_node_size(int32_t nsize);
209 
212 
215 
221 
226  void set_label_epsilon(float64_t epsilon);
227 
228  void pre_sort_features(CFeatures* data, SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
229 
230  void set_sorted_features(SGMatrix<float64_t>& sorted_feats, SGMatrix<index_t>& sorted_indices);
231 
232 protected:
237  virtual bool train_machine(CFeatures* data=NULL);
238 
247  virtual CBinaryTreeMachineNode<CARTreeNodeData>* CARTtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
248 
255  SGVector<float64_t> get_unique_labels(SGVector<float64_t> labels_vec, int32_t &n_ulabels);
256 
270  virtual int32_t compute_best_attribute(const SGMatrix<float64_t>& mat, const SGVector<float64_t>& weights, CLabels* labels,
271  SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing,
272  int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector<int32_t>& active_indices=SGVector<index_t>());
273 
274 
284 
285 
299  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
300  SGVector<float64_t> weights, float64_t p, int32_t attr);
301 
315  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
316  SGVector<float64_t> weights, float64_t p, int32_t attr);
317 
327 
335  float64_t gain(const SGVector<float64_t>& wleft, const SGVector<float64_t>& wright, const SGVector<float64_t>& wtotal);
336 
343  float64_t gini_impurity_index(const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight);
344 
352  float64_t least_squares_deviation(const SGVector<float64_t>& labels, const SGVector<float64_t>& weights, float64_t &total_weight);
353 
361 
367  void prune_by_cross_validation(CDenseFeatures<float64_t>* data, int32_t folds);
368 
378  float64_t compute_error(CLabels* labels, CLabels* reference, SGVector<float64_t> weights);
379 
386 
393 
399  void cut_weakest_link(bnode_t* node, float64_t alpha);
400 
405  void form_t1(bnode_t* node);
406 
408  void init();
409 
410 
411 public:
413  static const float64_t MISSING;
414 
416  static const float64_t MIN_SPLIT_GAIN;
417 
419  static const float64_t EQ_DELTA;
420 
421 protected:
424 
427 
430 
433 
436 
439 
442 
445 
448 
450  int32_t m_folds;
451 
454 
457 
459  int32_t m_max_depth;
460 
463 };
464 } /* namespace shogun */
465 
466 #endif /* _CARTREE_H__ */
void set_cv_pruning()
Definition: CARTree.h:211
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:1103
bool m_types_set
Definition: CARTree.h:441
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
Definition: CARTree.cpp:531
The node of the tree structure forming a TreeMachine. The node contains a pointer to its parent and pointers to its children.
int32_t get_max_depth() const
Definition: CARTree.cpp:219
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:174
Real Labels are real-valued labels.
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:197
int32_t get_min_node_size() const
Definition: CARTree.cpp:230
float64_t get_label_epsilon()
Definition: CARTree.h:220
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:87
int32_t get_num_folds() const
Definition: CARTree.cpp:208
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1365
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
SGMatrix< index_t > m_sorted_indices
Definition: CARTree.h:435
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:102
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1415
void unset_cv_pruning()
Definition: CARTree.h:214
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1466
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:92
virtual const char * get_name() const
Definition: CARTree.h:110
static const float64_t EQ_DELTA
Definition: CARTree.h:419
bool m_apply_cv_pruning
Definition: CARTree.h:447
virtual ~CCARTree()
Definition: CARTree.cpp:68
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:839
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:247
int32_t m_max_depth
Definition: CARTree.h:459
float64_t m_label_epsilon
Definition: CARTree.h:423
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:73
Multiclass Labels for multi-class classification.
void clear_feature_types()
Definition: CARTree.cpp:202
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:1051
EProblemType
Definition: Machine.h:110
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
Definition: CARTree.cpp:1087
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:235
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:290
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:914
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:213
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:1077
double float64_t
Definition: common.h:50
int32_t m_min_node_size
Definition: CARTree.h:462
int32_t m_folds
Definition: CARTree.h:450
bool m_weights_set
Definition: CARTree.h:444
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:180
void clear_weights()
Definition: CARTree.cpp:185
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an array.
SGMatrix< float64_t > m_sorted_features
Definition: CARTree.h:432
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:317
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:224
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:971
This class implements the Classification And Regression Trees algorithm by Breiman et al. for decision tree learning.
Definition: CARTree.h:79
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:297
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:191
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual EProblemType get_machine_problem_type() const
Definition: CARTree.h:115
SGVector< bool > m_nominal
Definition: CARTree.h:426
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1327
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1436
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:416
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:116
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:128
class TreeMachine, a base class for tree-based multiclass classifiers. This class is derived from CBaseMulticlassMachine.
Definition: TreeMachine.h:48
SGVector< float64_t > m_weights
Definition: CARTree.h:429
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:456
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:507
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:241
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1188
static const float64_t MISSING
Definition: CARTree.h:413
EProblemType m_mode
Definition: CARTree.h:453

SHOGUN Machine Learning Toolbox - Documentation