SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CHAIDTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
34 
35 using namespace shogun;
36 
38 
// Default constructor. NOTE(review): this documentation listing omits the
// signature line(s) (original lines 39-40); the visible body only calls
// init(), which resets all members to their defaults and registers them.
41 {
42  init();
43 }
44 
// Constructor selecting the dependent variable type (0 = nominal,
// 1 = ordinal, 2 = continuous, as listed in CHAIDtrain's error message).
// No range check is performed on dependent_vartype here.
// NOTE(review): original line 46 (presumably the base-class initializer) is
// missing from this listing.
45 CCHAIDTree::CCHAIDTree(int32_t dependent_vartype)
47 {
48  init();
49  m_dependent_vartype=dependent_vartype;
50 }
51 
// Constructor that additionally sets the per-feature types (0 nominal,
// 1 ordinal, 2 continuous) and the number of breakpoints used by
// continuous_to_ordinal() to discretize continuous features.
// NOTE(review): original line 53 is missing from this listing.
52 CCHAIDTree::CCHAIDTree(int32_t dependent_vartype, SGVector<int32_t> feature_types, int32_t num_breakpoints)
54 {
55  init();
56  m_dependent_vartype=dependent_vartype;
57  set_feature_types(feature_types);
58  m_num_breakpoints=num_breakpoints;
59 }
60 
// Destructor body (its signature line, original line 61, is not shown in this
// listing). No explicit cleanup is performed here.
62 {
63 }
64 
// Maps m_dependent_vartype to the machine problem type. The signature line is
// missing from this listing; presumably this is get_machine_problem_type(),
// which apply_from_current_node() calls.
//   0 (nominal), 1 (ordinal) -> PT_MULTICLASS
//   2 (continuous)           -> PT_REGRESSION
// Any other value raises SG_ERROR; the trailing return is reached only if
// SG_ERROR does not abort.
66 {
67  switch (m_dependent_vartype)
68  {
69  case 0:
70  return PT_MULTICLASS;
71  case 1:
72  return PT_MULTICLASS;
73  case 2:
74  return PT_REGRESSION;
75  default:
76  SG_ERROR("Invalid dependent variable type set (%d). Problem type undefined\n",m_dependent_vartype);
77  }
78 
79  return PT_MULTICLASS;
80 }
81 
// Checks that the label object `lab` matches the dependent variable type
// (signature line missing from this listing; presumably is_label_valid()).
// Nominal/ordinal (0/1) require LT_MULTICLASS labels, continuous (2) requires
// LT_REGRESSION. Any other m_dependent_vartype raises SG_ERROR.
// NOTE(review): `lab` is dereferenced without a NULL check.
83 {
84  switch (m_dependent_vartype)
85  {
86  case 0:
87  return lab->get_label_type()==LT_MULTICLASS;
88  case 1:
89  return lab->get_label_type()==LT_MULTICLASS;
90  case 2:
91  return lab->get_label_type()==LT_REGRESSION;
92  default:
93  SG_ERROR("Invalid dependent variable type set (%d). Problem type undefined\n",m_dependent_vartype);
94  }
95 
96  return false;
97 }
98 
// apply_multiclass body (signature line missing from this listing): requires
// non-NULL data, runs apply_tree() and converts the result to multiclass labels.
100 {
101  REQUIRE(data, "Data required for classification in apply_multiclass\n")
102 
103  return CLabelsFactory::to_multiclass(apply_tree(data));
104 }
105 
// apply_regression body (signature line missing from this listing): requires
// non-NULL data, runs apply_tree() and converts the result to regression labels.
107 {
108  REQUIRE(data, "Data required for regression in apply_regression\n")
109 
110  return CLabelsFactory::to_regression(apply_tree(data));
111 }
112 
// Accessor bodies; their signature lines are missing from this listing.
// Set user-defined per-vector weights and mark them as set.
114 {
115  m_weights=w;
116  m_weights_set=true;
117 }
118 
// Return the user-defined weights; raises SG_ERROR if none were set.
120 {
121  if (!m_weights_set)
122  SG_ERROR("weights not set\n");
123 
124  return m_weights;
125 }
126 
// Forget user-defined weights (training will then use unit weights).
128 {
129  m_weights=SGVector<float64_t>();
130  m_weights_set=false;
131 }
132 
// Set the per-feature types (0 nominal, 1 ordinal, 2 continuous).
134 {
135  m_feature_types=ft;
136 }
137 
// Return the per-feature types.
139 {
140  return m_feature_types;
141 }
142 
// Clear the per-feature types.
144 {
145  m_feature_types=SGVector<int32_t>();
146 }
147 
// Set the dependent variable type; only 0, 1 or 2 are accepted.
149 {
150  REQUIRE(((var==0)||(var==1)||(var==2)), "Expected 0 or 1 or 2 as argument. %d received\n",var)
151  m_dependent_vartype=var;
152 }
153 
// Training entry point (signature line missing from this listing; presumably
// train_machine(CFeatures* data)). Validates dense data and feature types,
// defaults to unit weights, discretizes continuous features in place, grows
// the tree with CHAIDtrain() and stores it as root.
// NOTE(review): original line 159, which presumably defines `feats` from
// `data`, is missing from this listing.
155 {
156  REQUIRE(data, "Data required for training\n")
157  REQUIRE(data->get_feature_class()==C_DENSE,"Dense data required for training\n")
158 
160 
161  REQUIRE(m_feature_types.vlen==feats->get_num_features(),"Either feature types are not set or the number of feature types specified"
162  " (%d here) is not the same as the number of features in data matrix (%d here)\n",m_feature_types.vlen,feats->get_num_features())
163 
164  if (m_weights_set)
165  {
166  REQUIRE(m_weights.vlen==feats->get_num_vectors(),"Length of weights vector (currently %d) should be same as"
167  " the number of vectors in data (presently %d)",m_weights.vlen,feats->get_num_vectors())
168  }
169  else
170  {
171  // all weights are equal to 1
// m_weights_set deliberately stays false: expected_cf_indep_model() uses
// that flag to pick its closed-form (unweighted) branch.
172  m_weights=SGVector<float64_t>(feats->get_num_vectors());
173  m_weights.fill_vector(m_weights.vector,m_weights.vlen,1.0);
174  }
175 
176  // continuous to ordinal conversion - NOTE: data matrix gets updated
177  bool updated=continuous_to_ordinal(feats);
178 
179  SGVector<int32_t> feature_types_cache;
180  if (updated)
181  {
182  // change m_feature_types momentarily
// discretized continuous features (type 2) are treated as ordinal (type 1)
// while the tree is grown
183  feature_types_cache=m_feature_types.clone();
184  for (int32_t i=0;i<m_feature_types.vlen;i++)
185  {
186  if (m_feature_types[i]==2)
187  m_feature_types[i]=1;
188  }
189  }
190 
191  set_root(CHAIDtrain(data,m_weights,m_labels,0));
192 
193  // restore feature types
// NOTE(review): not restored if CHAIDtrain raises an error.
194  if (updated)
195  m_feature_types=feature_types_cache;
196 
197  return true;
198 }
199 
// Recursively grows the CHAID tree for the vectors currently visible in
// `data` / `labels` (subsets are pushed before recursing and popped after).
// Steps:
// 1. Compute the node prediction: the weighted mean for a continuous
//    dependent variable, or the label with maximal total weight otherwise.
// 2. Apply the stopping rules.
// 3. For every feature, merge its categories and keep the feature with the
//    smallest adjusted p-value.
// 4. If that p-value is <= m_alpha_split, create one child per merged
//    category group.
// Parameters:
//   weights - weight of every currently visible vector
//   level   - depth of the node being built (root = 0)
// NOTE(review): original lines 207 (presumably the definition of the data
// matrix `mat`), 306 (initial value of `min_pv`), 346 (creation of
// `feat_index`) and 397 (allocation of node->data.distinct_features) are
// missing from this listing.
200 CTreeMachineNode<CHAIDTreeNodeData>* CCHAIDTree::CHAIDtrain(CFeatures* data, SGVector<float64_t> weights, CLabels* labels, int32_t level)
201 {
202  REQUIRE(data,"data matrix cannot be empty\n");
203  REQUIRE(labels,"labels cannot be NULL\n");
204 
205  node_t* node=new node_t();
// NOTE(review): the dynamic_cast result is not checked for NULL.
206  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
208  int32_t num_feats=mat.num_rows;
209  int32_t num_vecs=mat.num_cols;
210 
211  // calculate node label
212  if (m_dependent_vartype==2)
213  {
214  // sum_of_squared_deviation
215  node->data.weight_minus_node=sum_of_squared_deviation(labels_vec,weights,node->data.node_label);
216  node->data.total_weight=weights.sum(weights);
217  }
218  else if (m_dependent_vartype==0 || m_dependent_vartype==1)
219  {
// NOTE(review): `lab` is a sorted copy, but `weights` below is still indexed
// in the ORIGINAL order, so weights[i] does not belong to lab[i] once sorting
// moved elements. The mode is only correct for uniform weights.
220  SGVector<float64_t> lab=labels_vec.clone();
221  lab.qsort();
222  // stores max total weight for a single label
// NOTE(review): `max` and `c` are int32_t although weights are float64_t, so
// fractional weights are truncated.
223  int32_t max=weights[0];
224  // stores one of the indices having max total weight
225  int32_t maxi=0;
226  int32_t c=weights[0];
227  for (int32_t i=1;i<lab.vlen;i++)
228  {
229  if (lab[i]==lab[i-1])
230  {
231  c+=weights[i];
232  }
233  else if (c>max)
234  {
235  max=c;
236  maxi=i-1;
237  c=weights[i];
238  }
239  else
240  {
241  c=weights[i];
242  }
243  }
244 
245  if (c>max)
246  {
247  max=c;
248  maxi=lab.vlen-1;
249  }
250 
251  node->data.node_label=lab[maxi];
252  node->data.total_weight=weights.sum(weights);
254 
255  }
256  else
257  {
258  SG_ERROR("dependent variable type should be either 0(nominal) or 1(ordinal) or 2(continuous)\n");
259  }
260 
261  // check stopping rules
262  // case 1 : all labels same
263  SGVector<float64_t> lab=labels_vec.clone();
264  int32_t unique=lab.unique(lab.vector,lab.vlen);
265  if (unique==1)
266  return node;
267 
268  // case 2 : all non-dependent attributes (not MISSING) are same
// NOTE(review): only adjacent vectors are compared and the comparison is
// skipped when either value is MISSING, so values separated by a MISSING
// entry (e.g. A, MISSING, B) are never compared against each other.
269  bool flag=true;
270  for (int32_t v=1;v<num_vecs;v++)
271  {
272  for (int32_t f=0;f<num_feats;f++)
273  {
274  if ((mat(f,v)!=MISSING) && (mat(f,v-1)!=MISSING))
275  {
276  if (mat(f,v)!=mat(f,v-1))
277  {
278  flag=false;
279  break;
280  }
281  }
282  }
283 
284  if (!flag)
285  break;
286  }
287 
288  if (flag)
289  return node;
290 
291  // case 3 : current tree depth is equal to user specified max
292  if (m_max_tree_depth>0)
293  {
294  if (level==m_max_tree_depth)
295  return node;
296  }
297 
298  // case 4 : node size is less than user-specified min node size
299  if (m_min_node_size>1)
300  {
301  if (num_vecs<m_min_node_size)
302  return node;
303  }
304 
305  // choose best attribute for splitting
307  SGVector<int32_t> cat_min;
308  int32_t attr_min=-1;
309  for (int32_t i=0;i<num_feats;i++)
310  {
311  SGVector<float64_t> feats(num_vecs);
312  for (int32_t j=0;j<num_vecs;j++)
313  feats[j]=mat(i,j);
314 
315  float64_t pv=0;
316  SGVector<int32_t> cat;
317  if (m_feature_types[i]==0)
318  cat=merge_categories_nominal(feats,labels_vec,weights,pv);
319  else if (m_feature_types[i]==1)
320  cat=merge_categories_ordinal(feats,labels_vec,weights,pv);
321  else
322  SG_ERROR("feature type supported are 0(nominal) and 1(ordinal). m_feature_types[%d] is set %d\n",i,m_feature_types[i])
323 
324  if (pv<min_pv)
325  {
326  min_pv=pv;
327  attr_min=i;
328  cat_min=cat;
329  }
330  }
331 
332  if (min_pv>m_alpha_split)
333  return node;
334 
335  // split
// cat_min[k] is the merged group id of the k-th distinct (sorted) value of the
// chosen attribute; group representatives satisfy cat_min[i]==i.
336  SGVector<float64_t> ufeats_best(num_vecs);
337  for (int32_t i=0;i<num_vecs;i++)
338  ufeats_best[i]=mat(attr_min,i);
339 
340  int32_t unum=ufeats_best.unique(ufeats_best.vector,ufeats_best.vlen);
341  for (int32_t i=0;i<cat_min.vlen;i++)
342  {
343  if (cat_min[i]!=i)
344  continue;
345 
// collect the vectors whose attribute value falls into group i
347  for (int32_t j=0;j<num_vecs;j++)
348  {
349  for (int32_t k=0;k<unum;k++)
350  {
351  if (mat(attr_min,j)==ufeats_best[k])
352  {
353  if (cat_min[k]==i)
354  feat_index->push_back(j);
355  }
356  }
357  }
358 
359  SGVector<int32_t> subset(feat_index->get_num_elements());
360  SGVector<float64_t> subweights(feat_index->get_num_elements());
361  for (int32_t j=0;j<feat_index->get_num_elements();j++)
362  {
363  subset[j]=feat_index->get_element(j);
364  subweights[j]=weights[feat_index->get_element(j)];
365  }
366 
// subset indices are relative to the currently visible vectors; the subset is
// popped again at the end of this iteration
367  data->add_subset(subset);
368  labels->add_subset(subset);
369  node_t* child=CHAIDtrain(data,subweights,labels,level+1);
370  node->add_child(child);
371 
// NOTE(review): attribute_id, feature_class and distinct_features do not
// depend on i and are recomputed for every child; they could be set once
// after the loop.
372  node->data.attribute_id=attr_min;
// renumber group ids to consecutive child indices 0,1,2,...
373  int32_t c=0;
374  SGVector<int32_t> feat_class=cat_min.clone();
375  for (int32_t j=0;j<feat_class.vlen;j++)
376  {
377  if (feat_class[j]!=j)
378  {
379  continue;
380  }
381  else if (j==c)
382  {
383  c++;
384  continue;
385  }
386 
387  for (int32_t k=j;k<feat_class.vlen;k++)
388  {
389  if (feat_class[k]==j)
390  feat_class[k]=c;
391  }
392 
393  c++;
394  }
395 
396  node->data.feature_class=feat_class;
398  for (int32_t j=0;j<unum;j++)
399  node->data.distinct_features[j]=ufeats_best[j];
400 
401  SG_UNREF(feat_index);
402  data->remove_subset();
403  labels->remove_subset();
404  }
405 
406  return node;
407 }
408 
// Merges adjacent categories of an ordinal predictor `feats`.
// In each round it finds the adjacent pair of groups with the largest p-value
// (p_value on just those two groups). It merges that pair while the p-value
// exceeds m_alpha_merge and more than two groups remain. A trailing MISSING
// category (the largest distinct value, when equal to MISSING) is kept
// separate and handled afterwards by handle_missing_ordinal().
// Returns `cat`: cat[k] is the group id of the k-th distinct (sorted) value of
// `feats`.
// Output `pv`: the Bonferroni-adjusted p-value of the final grouping
// (1.0 if `feats` has a single distinct value).
// NOTE(review): original lines 455-456 (presumably creating the dynamic
// arrays `feat_index` and `feat_cat`) are missing from this listing.
409 SGVector<int32_t> CCHAIDTree::merge_categories_ordinal(SGVector<float64_t> feats, SGVector<float64_t> labels,
410  SGVector<float64_t> weights, float64_t &pv)
411 {
412  SGVector<float64_t> ufeats=feats.clone();
413  int32_t inum_cat=ufeats.unique(ufeats.vector,ufeats.vlen);
414  SGVector<int32_t> cat(inum_cat);
415  cat.range_fill(0);
416 
417  if (inum_cat==1)
418  {
419  pv=1.0;
420  return cat;
421  }
422 
423  bool missing=false;
424  if (ufeats[inum_cat-1]==MISSING)
425  {
426  missing=true;
427  inum_cat--;
428  }
429 
430  int32_t fnum_cat=inum_cat;
431 
432  // if chosen attribute (MISSING excluded) has 1 category only
// adjusted_p_value(...,2,2,...) returns the p-value unchanged (inum==fnum)
433  if (inum_cat==1)
434  {
435  pv=adjusted_p_value(p_value(feats,labels,weights),2,2,1,true);
436  return cat;
437  }
438 
439  while(true)
440  {
441  if (fnum_cat==2)
442  break;
443 
444  // scan all allowable pairs of categories to find most similar one
// index i marks the boundary between two adjacent groups (cat[i]!=cat[i+1])
445  int32_t cat_index_max=-1;
446  float64_t max_merge_pv=CMath::MIN_REAL_NUMBER;
447  for (int32_t i=0;i<inum_cat-1;i++)
448  {
449  if (cat[i]==cat[i+1])
450  continue;
451 
452  int32_t cat_index=i;
453 
454  // compute p-value
457  for (int32_t j=0;j<feats.vlen;j++)
458  {
459  for (int32_t k=0;k<inum_cat;k++)
460  {
461  if (feats[j]==ufeats[k])
462  {
463  if (cat[k]==cat[cat_index])
464  {
465  feat_index->push_back(j);
466  feat_cat->push_back(cat[cat_index]);
467  }
468  else if (cat[k]==cat[cat_index+1])
469  {
470  feat_index->push_back(j);
471  feat_cat->push_back(cat[cat_index+1]);
472  }
473  }
474  }
475  }
476 
477  SGVector<float64_t> subfeats(feat_index->get_num_elements());
478  SGVector<float64_t> sublabels(feat_index->get_num_elements());
479  SGVector<float64_t> subweights(feat_index->get_num_elements());
480  for (int32_t j=0;j<feat_index->get_num_elements();j++)
481  {
482  subfeats[j]=feat_cat->get_element(j);
483  sublabels[j]=labels[feat_index->get_element(j)];
484  subweights[j]=weights[feat_index->get_element(j)];
485  }
486 
487  float64_t pvalue=p_value(subfeats,sublabels,subweights);
488  if (pvalue>max_merge_pv)
489  {
490  max_merge_pv=pvalue;
491  cat_index_max=cat_index;
492  }
493 
494  SG_UNREF(feat_index);
495  SG_UNREF(feat_cat);
496  }
497 
498  if (max_merge_pv>m_alpha_merge)
499  {
500  // merge
// relabel the right-hand group with the left-hand group's id
501  int32_t cat2=cat[cat_index_max+1];
502  for (int32_t i=cat_index_max+1;i<inum_cat;i++)
503  {
504  if (cat2==cat[i])
505  cat[i]=cat[cat_index_max];
506  else
507  break;
508  }
509 
510  fnum_cat--;
511  }
512  else
513  {
514  break;
515  }
516  }
517 
// map every sample to its group id (MISSING stays MISSING)
518  SGVector<float64_t> feats_cat(feats.vlen);
519  for (int32_t i=0;i<feats.vlen;i++)
520  {
521  if (feats[i]==MISSING)
522  {
523  feats_cat[i]=MISSING;
524  continue;
525  }
526 
527  for (int32_t j=0;j<inum_cat;j++)
528  {
529  if (feats[i]==ufeats[j])
530  feats_cat[i]=cat[j];
531  }
532  }
533 
534  if (missing)
535  {
// handle_missing_ordinal may write into `cat` and `feats_cat` (relies on
// SGVector copies sharing storage - TODO confirm)
536  bool merged=handle_missing_ordinal(cat,feats_cat,labels,weights);
537  if (!merged)
538  fnum_cat+=1;
539 
540  pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat+1,fnum_cat,1,true);
541  }
542  else
543  {
544  pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat,fnum_cat,1,false);
545  }
546 
547  return cat;
548 }
549 
// Merges categories of a nominal predictor `feats`. Any two groups may be
// merged (not only adjacent ones). In each round, the pair of remaining groups
// with the largest p-value (p_value on just those two groups) is merged while
// that p-value exceeds m_alpha_merge and more than two groups remain. MISSING
// gets no special treatment here.
// Returns `cat`: cat[k] is the group id of the k-th distinct (sorted) value of
// `feats`.
// Output `pv`: the Bonferroni-adjusted p-value of the final grouping
// (1.0 if `feats` has a single distinct value).
// NOTE(review): original lines 573 (presumably creating `leftcat`) and
// 588-589 (creating `feat_index` / `feat_cat`) are missing from this listing.
550 SGVector<int32_t> CCHAIDTree::merge_categories_nominal(SGVector<float64_t> feats, SGVector<float64_t> labels,
551  SGVector<float64_t> weights, float64_t &pv)
552 {
553  SGVector<float64_t> ufeats=feats.clone();
554  int32_t inum_cat=ufeats.unique(ufeats.vector,ufeats.vlen);
555  int32_t fnum_cat=inum_cat;
556 
557  SGVector<int32_t> cat(inum_cat);
558  cat.range_fill(0);
559 
560  // if chosen attribute X(feats here) has 1 category only
561  if (inum_cat==1)
562  {
563  pv=1.0;
564  return cat;
565  }
566 
567  while(true)
568  {
569  if (fnum_cat==2)
570  break;
571 
572  // assimilate all category labels left
// a group is represented by the index i with cat[i]==i
574  for (int32_t i=0;i<cat.vlen;i++)
575  {
576  if (cat[i]==i)
577  leftcat->push_back(i);
578  }
579 
580  // consider all pairs for merging
581  float64_t max_merge_pv=CMath::MIN_REAL_NUMBER;
582  int32_t cat1_max=-1;
583  int32_t cat2_max=-1;
584  for (int32_t i=0;i<leftcat->get_num_elements()-1;i++)
585  {
586  for (int32_t j=i+1;j<leftcat->get_num_elements();j++)
587  {
590  for (int32_t k=0;k<feats.vlen;k++)
591  {
592  for (int32_t l=0;l<inum_cat;l++)
593  {
594  if (feats[k]==ufeats[l])
595  {
596  if (cat[l]==leftcat->get_element(i))
597  {
598  feat_index->push_back(k);
599  feat_cat->push_back(leftcat->get_element(i));
600  }
601  else if (cat[l]==leftcat->get_element(j))
602  {
603  feat_index->push_back(k);
604  feat_cat->push_back(leftcat->get_element(j));
605  }
606  }
607  }
608  }
609 
610  SGVector<float64_t> subfeats(feat_index->get_num_elements());
611  SGVector<float64_t> sublabels(feat_index->get_num_elements());
612  SGVector<float64_t> subweights(feat_index->get_num_elements());
613  for (int32_t k=0;k<feat_index->get_num_elements();k++)
614  {
615  subfeats[k]=feat_cat->get_element(k);
616  sublabels[k]=labels[feat_index->get_element(k)];
617  subweights[k]=weights[feat_index->get_element(k)];
618  }
619 
620  float64_t pvalue=p_value(subfeats,sublabels,subweights);
621  if (pvalue>max_merge_pv)
622  {
623  max_merge_pv=pvalue;
624  cat1_max=leftcat->get_element(i);
625  cat2_max=leftcat->get_element(j);
626  }
627 
628  SG_UNREF(feat_index);
629  SG_UNREF(feat_cat);
630  }
631  }
632 
633  SG_UNREF(leftcat);
634 
635  if (max_merge_pv>m_alpha_merge)
636  {
637  // merge
// cat1_max < cat2_max, so the merged group keeps a representative with
// cat[i]==i
638  for (int32_t i=0;i<cat.vlen;i++)
639  {
640  if (cat2_max==cat[i])
641  cat[i]=cat1_max;
642  }
643 
644  fnum_cat--;
645  }
646  else
647  {
648  break;
649  }
650  }
651 
652  SGVector<float64_t> feats_cat(feats.vlen);
653  for (int32_t i=0;i<feats.vlen;i++)
654  {
655  for (int32_t j=0;j<inum_cat;j++)
656  {
657  if (feats[i]==ufeats[j])
658  feats_cat[i]=cat[j];
659  }
660  }
661 
662  pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat,fnum_cat,0,false);
663  return cat;
664 }
665 
// Predicts labels for `data` starting at the tree root. If continuous
// breakpoints were learnt during training, the test matrix is discretized IN
// PLACE with them first.
// NOTE(review): original lines 668 (presumably obtaining dense `feats` from
// `data`) and 674 (obtaining `fmat`) are missing from this listing.
666 CLabels* CCHAIDTree::apply_tree(CFeatures* data)
667 {
669 
670  // modify test data matrix (continuous to ordinal)
671  if (m_cont_breakpoints.num_cols>0)
672  modify_data_matrix(feats);
673 
675  return apply_from_current_node(fmat, m_root);
676 }
677 
678 CLabels* CCHAIDTree::apply_from_current_node(SGMatrix<float64_t> fmat, node_t* current)
679 {
680  int32_t num_vecs=fmat.num_cols;
681 
682  SGVector<float64_t> labels(num_vecs);
683  for (int32_t i=0;i<num_vecs;i++)
684  {
685  node_t* node=current;
686  SG_REF(node);
687  CDynamicObjectArray* children=node->get_children();
688  // while leaf node not reached
689  while(children->get_num_elements()>0)
690  {
691  // find feature class (or index of child node) of chosen vector in current node
692  int32_t index=-1;
693  for (int32_t j=0;j<(node->data.distinct_features).vlen;j++)
694  {
695  if (fmat(node->data.attribute_id,i)==node->data.distinct_features[j])
696  {
697  index=j;
698  break;
699  }
700  }
701 
702  if (index==-1)
703  break;
704 
705  CSGObject* el=children->get_element(node->data.feature_class[index]);
706  if (el!=NULL)
707  {
708  SG_UNREF(node);
709  node=dynamic_cast<node_t*>(el);
710  }
711  else
712  SG_ERROR("%d child is expected to be present. But it is NULL\n",index)
713 
714  SG_UNREF(children);
715  children=node->get_children();
716  }
717 
718  labels[i]=node->data.node_label;
719  SG_UNREF(node);
720  SG_UNREF(children);
721  }
722 
723  switch (get_machine_problem_type())
724  {
725  case PT_MULTICLASS:
726  return new CMulticlassLabels(labels);
727  case PT_REGRESSION:
728  return new CRegressionLabels(labels);
729  default:
730  SG_ERROR("Undefined problem type\n")
731  }
732 
733  return new CMulticlassLabels();
734 }
735 
// Decides whether the MISSING category of an ordinal predictor should join one
// of the other groups ("floating" category).
// Inputs:
//   cat  - group ids per distinct value; the last entry belongs to MISSING and
//          must still be unmerged (checked by REQUIRE)
//   feats - per-sample group id (MISSING where the value is missing)
// Procedure: finds the group whose union with MISSING has the largest p-value,
// then compares p_value with MISSING merged into it vs. kept separate.
// If merged, it writes the chosen group id into cat's last entry and into the
// MISSING entries of `feats`, and returns true. These writes are visible to the
// caller only if SGVector copies share storage (assumed reference-counted -
// TODO confirm). Otherwise returns false.
// NOTE(review): original lines 742 (presumably creating `cat_ind`) and 754
// (creating `feat_index`) are missing from this listing. No SG_UNREF(cat_ind)
// is visible, so `cat_ind` may leak. If there are no non-missing groups,
// cindex_max stays -1 and is used as a category id below.
736 bool CCHAIDTree::handle_missing_ordinal(SGVector<int32_t> cat, SGVector<float64_t> feats, SGVector<float64_t> labels,
737  SGVector<float64_t> weights)
738 {
739  // assimilate category indices other than missing (last cell of cat vector stores category index for missing)
740  // sanity check
741  REQUIRE(cat[cat.vlen-1]==cat.vlen-1,"last category is expected to be stored for MISSING. Hence it is expected to be un-merged\n")
743  for (int32_t i=0;i<cat.vlen-1;i++)
744  {
745  if (cat[i]==i)
746  cat_ind->push_back(i);
747  }
748 
749  // find most similar category to MISSING
750  float64_t max_pv_pair=CMath::MIN_REAL_NUMBER;
751  int32_t cindex_max=-1;
752  for (int32_t i=0;i<cat_ind->get_num_elements();i++)
753  {
755  for (int32_t j=0;j<feats.vlen;j++)
756  {
757  if ((feats[j]==cat_ind->get_element(i)) || feats[j]==MISSING)
758  feat_index->push_back(j);
759  }
760 
761  SGVector<float64_t> subfeats(feat_index->get_num_elements());
762  SGVector<float64_t> sublabels(feat_index->get_num_elements());
763  SGVector<float64_t> subweights(feat_index->get_num_elements());
764  for (int32_t j=0;j<feat_index->get_num_elements();j++)
765  {
766  subfeats[j]=feats[feat_index->get_element(j)];
767  sublabels[j]=labels[feat_index->get_element(j)];
768  subweights[j]=weights[feat_index->get_element(j)];
769  }
770 
771  float64_t pvalue=p_value(subfeats,sublabels,subweights);
772  if (pvalue>max_pv_pair)
773  {
774  max_pv_pair=pvalue;
775  cindex_max=cat_ind->get_element(i);
776  }
777 
778  SG_UNREF(feat_index);
779  }
780 
781  // compare if MISSING being merged is better than not being merged
// NOTE(review): merging is kept when it gives the LARGER unadjusted p-value
// (weaker association); confirm this direction against the CHAID spec.
782  SGVector<float64_t> feats_copy(feats.vlen);
783  for (int32_t i=0;i<feats.vlen;i++)
784  {
785  if (feats[i]==MISSING)
786  feats_copy[i]=cindex_max;
787  else
788  feats_copy[i]=feats[i];
789  }
790 
791  float64_t pv_merged=p_value(feats_copy, labels, weights);
792  float64_t pv_unmerged=p_value(feats, labels, weights);
793  if (pv_merged>pv_unmerged)
794  {
795  cat[cat.vlen-1]=cindex_max;
796  for (int32_t i=0;i<feats.vlen;i++)
797  {
798  if (feats[i]==MISSING)
799  feats[i]=cindex_max;
800  }
801 
802  return true;
803  }
804 
805  return false;
806 }
807 
808 float64_t CCHAIDTree::adjusted_p_value(float64_t up_value, int32_t inum_cat, int32_t fnum_cat, int32_t ft, bool is_missing)
809 {
810 
811  if (inum_cat==fnum_cat)
812  return up_value;
813 
814  switch (ft)
815  {
816  case 0:
817  {
818  float64_t sum=0.;
819  for (int32_t v=0;v<fnum_cat;v++)
820  {
821  float64_t lterm=inum_cat*CMath::log(fnum_cat-v);
822  for (int32_t j=1;j<=v;j++)
823  lterm-=CMath::log(j);
824 
825  for (int32_t j=1;j<=fnum_cat-v;j++)
826  lterm-=CMath::log(j);
827 
828  if (v%2==0)
829  sum+=CMath::exp(lterm);
830  else
831  sum-=CMath::exp(lterm);
832  }
833 
834  return sum*up_value;
835  }
836  case 1:
837  {
838  if (!is_missing)
839  return CMath::nchoosek(inum_cat-1,fnum_cat-1)*up_value;
840  else
841  return up_value*(CMath::nchoosek(inum_cat-2,fnum_cat-2)+fnum_cat*CMath::nchoosek(inum_cat-2,fnum_cat-1));
842  }
843  default:
844  SG_ERROR("Feature type must be either 0 (nominal) or 1 (ordinal). It is currently set as %d\n",ft)
845  }
846 
847  return 0.0;
848 }
849 
850 float64_t CCHAIDTree::p_value(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights)
851 {
852  switch (m_dependent_vartype)
853  {
854  case 0:
855  {
856  int32_t r=0;
857  int32_t c=0;
858  float64_t x2=pchi2_statistic(feat,labels,weights,r,c);
859  return 1-CStatistics::chi2_cdf(x2,(r-1)*(c-1));
860  }
861  case 1:
862  {
863  int32_t r=0;
864  int32_t c=0;
865  float64_t h2=likelihood_ratio_statistic(feat,labels,weights,r,c);
866  return 1-CStatistics::chi2_cdf(h2,(r-1));
867  }
868  case 2:
869  {
870  int32_t nf=feat.vlen;
871  int32_t num_cat=0;
872  float64_t f=anova_f_statistic(feat,labels,weights,num_cat);
873 
874  if (nf==num_cat)
875  return 1.0;
876 
877  return 1-CStatistics::fdistribution_cdf(f,num_cat-1,nf-num_cat);
878  }
879  default:
880  SG_ERROR("Dependent variable type must be either 0 or 1 or 2. It is currently set as %d\n",m_dependent_vartype)
881  }
882 
883  return -1.0;
884 }
885 
886 float64_t CCHAIDTree::anova_f_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights, int32_t &r)
887 {
888  // compute y_bar
889  float64_t y_bar=0.;
890  for (int32_t i=0;i<labels.vlen;i++)
891  y_bar+=labels[i]*weights[i];
892 
893  y_bar/=weights.sum(weights);
894 
895  SGVector<float64_t> ufeat=feat.clone();
896  r=ufeat.unique(ufeat.vector,ufeat.vlen);
897 
898  // compute y_i_bar
899  SGVector<float64_t> numer(r);
900  SGVector<float64_t> denom(r);
901  numer.zero();
902  denom.zero();
903  for (int32_t n=0;n<feat.vlen;n++)
904  {
905  for (int32_t i=0;i<r;i++)
906  {
907  if (feat[n]==ufeat[i])
908  {
909  numer[i]+=weights[n]*labels[n];
910  denom[i]+=weights[n];
911  break;
912  }
913  }
914  }
915 
916  // compute f statistic
917  float64_t nu=0.;
918  float64_t de=0.;
919  for (int32_t i=0;i<r;i++)
920  {
921  for (int32_t n=0;n<feat.vlen;n++)
922  {
923  if (feat[n]==ufeat[i])
924  {
925  nu+=weights[n]*CMath::pow(((numer[i]/denom[i])-y_bar),2);
926  de+=weights[n]*CMath::pow((labels[n]-(numer[i]/denom[i])),2);
927  }
928  }
929  }
930 
931  nu/=(r-1.0);
932  if (de==0)
933  return nu;
934 
935  de/=(feat.vlen-r-0.f);
936  return nu/de;
937 }
938 
939 float64_t CCHAIDTree::likelihood_ratio_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels,
940  SGVector<float64_t> weights, int32_t &r, int32_t &c)
941 {
942  SGVector<float64_t> ufeat=feat.clone();
943  SGVector<float64_t> ulabels=labels.clone();
944  r=ufeat.unique(ufeat.vector,ufeat.vlen);
945  c=ulabels.unique(ulabels.vector,ulabels.vlen);
946 
947  // contingency table, weight table
948  SGMatrix<int32_t> ct(r,c);
949  ct.zero();
950  SGMatrix<float64_t> wt(r,c);
951  wt.zero();
952  for (int32_t i=0;i<feat.vlen;i++)
953  {
954  // calculate row
955  int32_t row=-1;
956  for (int32_t j=0;j<r;j++)
957  {
958  if (feat[i]==ufeat[j])
959  {
960  row=j;
961  break;
962  }
963  }
964 
965  // calculate col
966  int32_t col=-1;
967  for (int32_t j=0;j<c;j++)
968  {
969  if (labels[i]==ulabels[j])
970  {
971  col=j;
972  break;
973  }
974  }
975 
976  ct(row,col)++;
977  wt(row,col)+=weights[i];
978  }
979 
980  SGMatrix<float64_t> expmat_indep=expected_cf_indep_model(ct,wt);
981 
982  SGVector<float64_t> score(c);
983  score.range_fill(1.0);
984  SGMatrix<float64_t> expmat_row_effects=expected_cf_row_effects_model(ct,wt,score);
985 
986  float64_t ret=0.;
987  for (int32_t i=0;i<r;i++)
988  {
989  for (int32_t j=0;j<c;j++)
990  ret+=expmat_row_effects(i,j)*CMath::log(expmat_row_effects(i,j)/expmat_indep(i,j));
991  }
992 
993  return 2*ret;
994 }
995 
996 float64_t CCHAIDTree::pchi2_statistic(SGVector<float64_t> feat, SGVector<float64_t> labels, SGVector<float64_t> weights,
997  int32_t &r, int32_t &c)
998 {
999  SGVector<float64_t> ufeat=feat.clone();
1000  SGVector<float64_t> ulabels=labels.clone();
1001  r=ufeat.unique(ufeat.vector,ufeat.vlen);
1002  c=ulabels.unique(ulabels.vector,ulabels.vlen);
1003 
1004  // contingency table, weight table
1005  SGMatrix<int32_t> ct(r,c);
1006  ct.zero();
1007  SGMatrix<float64_t> wt(r,c);
1008  wt.zero();
1009  for (int32_t i=0;i<feat.vlen;i++)
1010  {
1011  // calculate row
1012  int32_t row=-1;
1013  for (int32_t j=0;j<r;j++)
1014  {
1015  if (feat[i]==ufeat[j])
1016  {
1017  row=j;
1018  break;
1019  }
1020  }
1021 
1022  // calculate col
1023  int32_t col=-1;
1024  for (int32_t j=0;j<c;j++)
1025  {
1026  if (labels[i]==ulabels[j])
1027  {
1028  col=j;
1029  break;
1030  }
1031  }
1032 
1033  ct(row,col)++;
1034  wt(row,col)+=weights[i];
1035  }
1036 
1037  SGMatrix<float64_t> expected_cf=expected_cf_indep_model(ct,wt);
1038 
1039  float64_t ret=0.;
1040  for (int32_t i=0;i<r;i++)
1041  {
1042  for (int32_t j=0;j<c;j++)
1043  ret+=CMath::pow((ct(i,j)-expected_cf(i,j)),2)/expected_cf(i,j);
1044  }
1045 
1046  return ret;
1047 }
1048 
1049 SGMatrix<float64_t> CCHAIDTree::expected_cf_row_effects_model(SGMatrix<int32_t> ct, SGMatrix<float64_t> wt, SGVector<float64_t> score)
1050 {
1051  int32_t r=ct.num_rows;
1052  int32_t c=ct.num_cols;
1053 
1054  // compute row sum(n_i.'s) and column sum(n_.j's)
1055  SGVector<int32_t> row_sum(r);
1056  SGVector<int32_t> col_sum(c);
1057  for (int32_t i=0;i<r;i++)
1058  {
1059  int32_t sum=0;
1060  for (int32_t j=0;j<c;j++)
1061  sum+=ct(i,j);
1062 
1063  row_sum[i]=sum;
1064  }
1065  for (int32_t i=0;i<c;i++)
1066  {
1067  int32_t sum=0;
1068  for (int32_t j=0;j<r;j++)
1069  sum+=ct(j,i);
1070 
1071  col_sum[i]=sum;
1072  }
1073 
1074  // compute s_bar
1075  float64_t numer=0.;
1076  float64_t denom=0.;
1077  for (int32_t j=0;j<c;j++)
1078  {
1079  float64_t w_j=0.;
1080  for (int32_t i=0;i<r;i++)
1081  w_j+=wt(i,j);
1082 
1083  denom+=w_j;
1084  numer+=w_j*score[j];
1085  }
1086 
1087  float64_t s_bar=numer/denom;
1088 
1089  // element-wise normalize and invert weight matrix w_ij(new)=n_ij/w_ij(old)
1090  for (int32_t i=0;i<r;i++)
1091  {
1092  for (int32_t j=0;j<c;j++)
1093  wt(i,j)=(ct(i,j)-0.f)/wt(i,j);
1094  }
1095 
1096  SGMatrix<float64_t> m_k=wt.clone();
1097  SGVector<float64_t> alpha(r);
1098  SGVector<float64_t> beta(c);
1099  SGVector<float64_t> gamma(r);
1100  alpha.fill_vector(alpha.vector,alpha.vlen,1.0);
1101  beta.fill_vector(beta.vector,beta.vlen,1.0);
1102  gamma.fill_vector(gamma.vector,gamma.vlen,1.0);
1103  float64_t epsilon=1e-8;
1104  while(true)
1105  {
1106  // update alpha
1107  for (int32_t i=0;i<r;i++)
1108  {
1109  float64_t sum=0.;
1110  for (int32_t j=0;j<c;j++)
1111  sum+=m_k(i,j);
1112 
1113  alpha[i]*=(row_sum[i]-0.f)/sum;
1114  }
1115 
1116  // update beta
1117  for (int32_t j=0;j<c;j++)
1118  {
1119  float64_t sum=0.;
1120  for (int32_t i=0;i<r;i++)
1121  sum+=wt(i,j)*alpha[i]*CMath::pow(gamma[i],(score[j]-s_bar));
1122 
1123  beta[j]=(col_sum[j]-0.f)/sum;
1124  }
1125 
1126  // compute g_i for updating gamma
1127  SGVector<float64_t> g(r);
1128  SGMatrix<float64_t> m_star(r,c);
1129  for (int32_t i=0;i<r;i++)
1130  {
1131  for (int32_t j=0;j<c;j++)
1132  m_star(i,j)=wt(i,j)*alpha[i]*beta[j]*CMath::pow(gamma[i],score[j]-s_bar);
1133  }
1134 
1135  for (int32_t i=0;i<r;i++)
1136  {
1137  numer=0.;
1138  denom=0.;
1139  for (int32_t j=0;j<c;j++)
1140  {
1141  numer+=(score[j]-s_bar)*(ct(i,j)-m_star(i,j));
1142  denom+=CMath::pow((score[j]-s_bar),2)*m_star(i,j);
1143  }
1144 
1145  g[i]=1+numer/denom;
1146  }
1147 
1148  // update gamma
1149  for (int32_t i=0;i<r;i++)
1150  gamma[i]=(g[i]>0)?gamma[i]*g[i]:gamma[i];
1151 
1152  // update m_k
1153  SGMatrix<float64_t> m_kplus(r,c);
1154  float64_t max_diff=0.;
1155  for (int32_t i=0;i<r;i++)
1156  {
1157  for (int32_t j=0;j<c;j++)
1158  {
1159  m_kplus(i,j)=wt(i,j)*alpha[i]*beta[j]*CMath::pow(gamma[i],(score[j]-s_bar));
1160  float64_t abs_diff=CMath::abs(m_kplus(i,j)-m_k(i,j));
1161  if (abs_diff>max_diff)
1162  max_diff=abs_diff;
1163  }
1164  }
1165 
1166  m_k=m_kplus;
1167  if (max_diff<epsilon)
1168  break;
1169  }
1170 
1171  return m_k;
1172 }
1173 
1174 SGMatrix<float64_t> CCHAIDTree::expected_cf_indep_model(SGMatrix<int32_t> ct, SGMatrix<float64_t> wt)
1175 {
1176  int32_t r=ct.num_rows;
1177  int32_t c=ct.num_cols;
1178 
1179  // compute row sum(n_i.'s) and column sum(n_.j's)
1180  SGVector<int32_t> row_sum(r);
1181  SGVector<int32_t> col_sum(c);
1182  for (int32_t i=0;i<r;i++)
1183  {
1184  int32_t sum=0;
1185  for (int32_t j=0;j<c;j++)
1186  sum+=ct(i,j);
1187 
1188  row_sum[i]=sum;
1189  }
1190  for (int32_t i=0;i<c;i++)
1191  {
1192  int32_t sum=0;
1193  for (int32_t j=0;j<r;j++)
1194  sum+=ct(j,i);
1195 
1196  col_sum[i]=sum;
1197  }
1198 
1199  SGMatrix<float64_t> ret(r,c);
1200 
1201  // if all weights are 1 - m_ij=n_i.*n_.j/n..
1202  if (!m_weights_set)
1203  {
1204  int32_t total_sum=(r<=c)?row_sum.sum(row_sum):col_sum.sum(col_sum);
1205 
1206  for (int32_t i=0;i<r;i++)
1207  {
1208  for (int32_t j=0;j<c;j++)
1209  ret(i,j)=(row_sum[i]*col_sum[j]-0.f)/(total_sum-0.f);
1210  }
1211  }
1212  else
1213  {
1214  // element-wise normalize and invert weight matrix w_ij(new)=n_ij/w_ij(old)
1215  for (int32_t i=0;i<r;i++)
1216  {
1217  for (int32_t j=0;j<c;j++)
1218  wt(i,j)=(ct(i,j)-0.f)/wt(i,j);
1219  }
1220 
1221  // iteratively estimate mij
1222  SGMatrix<float64_t> m_k=wt.clone();
1223  SGVector<float64_t> alpha(r);
1224  SGVector<float64_t> beta(c);
1225  alpha.fill_vector(alpha.vector,alpha.vlen,1.0);
1226  beta.fill_vector(beta.vector,beta.vlen,1.0);
1227  float64_t epsilon=1e-8;
1228  while (true)
1229  {
1230  // update alpha
1231  for (int32_t i=0;i<r;i++)
1232  {
1233  float64_t sum=0.;
1234  for (int32_t j=0;j<c;j++)
1235  sum+=m_k(i,j);
1236 
1237  alpha[i]*=(row_sum[i]-0.f)/sum;
1238  }
1239 
1240  // update beta
1241  for (int32_t j=0;j<c;j++)
1242  {
1243  float64_t sum=0.;
1244  for (int32_t i=0;i<r;i++)
1245  sum+=wt(i,j)*alpha[i];
1246 
1247  beta[j]=(col_sum[j]-0.f)/sum;
1248  }
1249 
1250  // update m_k
1251  SGMatrix<float64_t> m_kplus(r,c);
1252  float64_t max_diff=0.0;
1253  for (int32_t i=0;i<r;i++)
1254  {
1255  for (int32_t j=0;j<c;j++)
1256  {
1257  m_kplus(i,j)=wt(i,j)*alpha[i]*beta[j];
1258  float64_t abs_diff=CMath::abs(m_kplus(i,j)-m_k(i,j));
1259  if (abs_diff>max_diff)
1260  max_diff=abs_diff;
1261  }
1262  }
1263 
1264  m_k=m_kplus;
1265 
1266  if (max_diff<epsilon)
1267  break;
1268  }
1269 
1270  ret=m_k;
1271  }
1272 
1273  return ret;
1274 }
1275 
1276 float64_t CCHAIDTree::sum_of_squared_deviation(SGVector<float64_t> lab, SGVector<float64_t> weights, float64_t &mean)
1277 {
1278  mean=0;
1279  float64_t total_weight=0;
1280  for (int32_t i=0;i<lab.vlen;i++)
1281  {
1282  mean+=lab[i]*weights[i];
1283  total_weight+=weights[i];
1284  }
1285 
1286  mean/=total_weight;
1287  float64_t dev=0;
1288  for (int32_t i=0;i<lab.vlen;i++)
1289  dev+=weights[i]*(lab[i]-mean)*(lab[i]-mean);
1290 
1291  return dev;
1292 }
1293 
1294 bool CCHAIDTree::continuous_to_ordinal(CDenseFeatures<float64_t>* feats)
1295 {
1296  // assimilate continuous breakpoints
1297  int32_t count_cont=0;
1298  for (int32_t i=0;i<feats->get_num_features();i++)
1299  {
1300  if (m_feature_types[i]==2)
1301  count_cont++;
1302  }
1303 
1304  if (count_cont==0)
1305  return false;
1306 
1307  REQUIRE(m_num_breakpoints>0,"Number of breakpoints for continuous to ordinal conversion not set.\n")
1308 
1309  SGVector<int32_t> cont_ind(count_cont);
1310  int32_t ci=0;
1311  for (int32_t i=0;i<feats->get_num_features();i++)
1312  {
1313  if (m_feature_types[i]==2)
1314  cont_ind[ci++]=i;
1315  }
1316 
1317  // form breakpoints matrix
1318  m_cont_breakpoints=SGMatrix<float64_t>(m_num_breakpoints,count_cont);
1319  int32_t bin_size=feats->get_num_vectors()/m_num_breakpoints;
1320  for (int32_t i=0;i<count_cont;i++)
1321  {
1322  int32_t left=feats->get_num_vectors()%m_num_breakpoints;
1323  int32_t end_pt=-1;
1324 
1325  SGVector<float64_t> values(feats->get_num_vectors());
1326  for (int32_t j=0;j<values.vlen;j++)
1327  values[j]=feats->get_feature_vector(j)[cont_ind[i]];
1328 
1329  values.qsort();
1330 
1331  for (int32_t j=0;j<m_num_breakpoints;j++)
1332  {
1333  if (left>0)
1334  {
1335  left--;
1336  end_pt+=bin_size+1;
1337  m_cont_breakpoints(j,i)=values[end_pt];
1338  }
1339  else
1340  {
1341  end_pt+=bin_size;
1342  m_cont_breakpoints(j,i)=values[end_pt];
1343  }
1344  }
1345  }
1346 
1347  // update data matrix
1348  modify_data_matrix(feats);
1349 
1350  return true;
1351 }
1352 
1353 void CCHAIDTree::modify_data_matrix(CDenseFeatures<float64_t>* feats)
1354 {
1355  int32_t c=0;
1356  for (int32_t i=0;i<feats->get_num_features();i++)
1357  {
1358  if (m_feature_types[i]!=2)
1359  continue;
1360 
1361  // continuous to ordinal conversion
1362  for (int32_t j=0;j<feats->get_num_vectors();j++)
1363  {
1364  for (int32_t k=0;k<m_num_breakpoints;k++)
1365  {
1366  if (feats->get_feature_vector(j)[i]<=m_cont_breakpoints(k,c))
1367  {
1368  feats->get_feature_vector(j)[i]=m_cont_breakpoints(k,c);
1369  break;
1370  }
1371  }
1372  }
1373 
1374  c++;
1375  }
1376 }
1377 
1378 void CCHAIDTree::init()
1379 {
1380  m_feature_types=SGVector<int32_t>();
1381  m_weights=SGVector<float64_t>();
1382  m_dependent_vartype=0;
1383  m_weights_set=false;
1384  m_max_tree_depth=0;
1385  m_min_node_size=0;
1386  m_alpha_merge=0.05;
1387  m_alpha_split=0.05;
1388  m_cont_breakpoints=SGMatrix<float64_t>();
1389  m_num_breakpoints=0;
1390 
1391  SG_ADD(&m_weights,"m_weights", "weights", MS_NOT_AVAILABLE);
1392  SG_ADD(&m_weights_set,"m_weights_set", "weights set", MS_NOT_AVAILABLE);
1393  SG_ADD(&m_feature_types,"m_feature_types", "feature types", MS_NOT_AVAILABLE);
1394  SG_ADD(&m_dependent_vartype,"m_dependent_vartype", "dependent variable type", MS_NOT_AVAILABLE);
1395  SG_ADD(&m_max_tree_depth,"m_max_tree_depth", "max tree depth", MS_NOT_AVAILABLE);
1396  SG_ADD(&m_min_node_size,"m_min_node_size", "min node size", MS_NOT_AVAILABLE);
1397  SG_ADD(&m_alpha_merge,"m_alpha_merge", "alpha-merge", MS_NOT_AVAILABLE);
1398  SG_ADD(&m_alpha_split,"m_alpha_split", "alpha-split", MS_NOT_AVAILABLE);
1399  SG_ADD(&m_cont_breakpoints,"m_cont_breakpoints", "breakpoints in continuous attributes", MS_NOT_AVAILABLE);
1400  SG_ADD(&m_num_breakpoints,"m_num_breakpoints", "number of breakpoints", MS_NOT_AVAILABLE);
1401 }

SHOGUN Machine Learning Toolbox - Documentation