SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CHAIDTree.h
Go to the documentation of this file.
/*
 * Copyright (c) The Shogun Machine Learning Toolbox
 * Written (w) 2014 Parijat Mazumdar
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * The views and conclusions contained in the software and documentation are those
 * of the authors and should not be interpreted as representing official policies,
 * either expressed or implied, of the Shogun Development Team.
 */


32 #ifndef _CHAIDTree_H__
33 #define _CHAIDTree_H__
34 
35 #include <shogun/lib/config.h>
36 
40 
41 namespace shogun
42 {
43 
90 class CCHAIDTree : public CTreeMachine<CHAIDTreeNodeData>
91 {
92 public:
94  CCHAIDTree();
95 
99  CCHAIDTree(int32_t dependent_vartype);
100 
106  CCHAIDTree(int32_t dependent_vartype, SGVector<int32_t> feature_types, int32_t num_breakpoints=0);
107 
109  virtual ~CCHAIDTree();
110 
114  virtual const char* get_name() const { return "CHAIDTree"; }
115 
119  virtual EProblemType get_machine_problem_type() const;
120 
125  virtual bool is_label_valid(CLabels* lab) const;
126 
133  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
134 
141  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
142 
147 
152 
154  void clear_weights();
155 
160 
165 
167  void clear_feature_types();
168 
172  void set_dependent_vartype(int32_t var);
173 
177  int32_t get_dependent_vartype() const { return m_dependent_vartype; }
178 
182  void set_max_tree_depth(int32_t d) { m_max_tree_depth=d; }
183 
187  int32_t get_specified_max_tree_depth() const { return m_max_tree_depth; }
188 
192  void set_min_node_size(int32_t size) { m_min_node_size=size; }
193 
197  int32_t get_min_node_size() const { return m_min_node_size; }
198 
202  void set_alpha_merge(float64_t a) { m_alpha_merge=a; }
203 
207  float64_t get_alpha_merge() const { return m_alpha_merge; }
208 
212  void set_alpha_split(float64_t a) { m_alpha_split=a; }
213 
217  float64_t get_alpha_split() const { return m_alpha_split; }
218 
222  void set_num_breakpoints(int32_t b) { m_num_breakpoints=b; }
223 
227  float64_t get_num_breakpoints() const { return m_num_breakpoints; }
228 
229 protected:
234  virtual bool train_machine(CFeatures* data=NULL);
235 
236 private:
245  CTreeMachineNode<CHAIDTreeNodeData>* CHAIDtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level);
246 
255  SGVector<int32_t> merge_categories_ordinal(SGVector<float64_t> feats, SGVector<float64_t> labels,
256  SGVector<float64_t> weights, float64_t &pv);
257 
266  SGVector<int32_t> merge_categories_nominal(SGVector<float64_t> feats, SGVector<float64_t> labels,
267  SGVector<float64_t> weights, float64_t &pv);
268 
274  CLabels* apply_tree(CFeatures* data);
275 
282  CLabels* apply_from_current_node(SGMatrix<float64_t> fmat, node_t* current);
283 
292  bool handle_missing_ordinal(SGVector<int32_t> cat, SGVector<float64_t> feats, SGVector<float64_t> labels, SGVector<float64_t> weights);
293 
303  float64_t adjusted_p_value(float64_t p_value, int32_t inum_cat, int32_t fnum_cat, int32_t ft, bool is_missing);
304 
313 
322  float64_t anova_f_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights, int32_t &r);
323 
333  float64_t likelihood_ratio_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights,
334  int32_t &r, int32_t &c);
335 
345  float64_t pchi2_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights, int32_t &r, int32_t &c);
346 
354  SGMatrix<float64_t> expected_cf_row_effects_model(SGMatrix<int32_t> ct, SGMatrix<float64_t> wt, SGVector<float64_t> score);
355 
362  SGMatrix<float64_t> expected_cf_indep_model(SGMatrix<int32_t> ct, SGMatrix<float64_t> wt);
363 
370  float64_t sum_of_squared_deviation(SGVector<float64_t> lab, SGVector<float64_t> weights, float64_t &mean);
371 
379  bool continuous_to_ordinal(CDenseFeatures<float64_t>* feats);
380 
386  void modify_data_matrix(CDenseFeatures<float64_t>* feats);
387 
389  void init();
390 
391 public:
393  static const float64_t MISSING;
394 
395 private:
397  SGVector<int32_t> m_feature_types;
398 
400  SGVector<float64_t> m_weights;
401 
403  bool m_weights_set;
404 
406  int32_t m_dependent_vartype;
407 
409  int32_t m_max_tree_depth;
410 
412  int32_t m_min_node_size;
413 
415  float64_t m_alpha_merge;
416 
418  float64_t m_alpha_split;
419 
421  SGMatrix<float64_t> m_cont_breakpoints;
422 
424  int32_t m_num_breakpoints;
425 
426 };
427 } /* namespace shogun */
428 
429 #endif /* _CHAIDTree_H__ */

SHOGUN Machine Learning Toolbox - Documentation