SHOGUN  6.1.3
CARTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
34 
35 using namespace Eigen;
36 using namespace shogun;
37 
// Sentinel value used throughout CARTree to mark a missing feature entry.
38 const float64_t CCARTree::MISSING=CMath::MAX_REAL_NUMBER;
// Tolerance for treating two feature values as equal when scanning split thresholds.
39 const float64_t CCARTree::EQ_DELTA=1e-7;
// Minimum gain a candidate split must exceed to be accepted during training.
40 const float64_t CCARTree::MIN_SPLIT_GAIN=1e-7;
41 
// Default constructor: delegates all member initialization to init().
// NOTE(review): doxygen line 43 (presumably the base-class initializer list)
// is missing from this extracted listing — confirm against the original file.
42 CCARTree::CCARTree()
44 {
45  init();
46 }
47 
// Constructor taking per-feature type flags (true = nominal/categorical,
// false = continuous) and the problem type (multiclass vs regression).
// NOTE(review): doxygen line 49 (presumably the base-class initializer list)
// is missing from this extracted listing.
48 CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type)
50 {
51  init();
52  set_feature_types(attribute_types);
53  set_machine_problem_type(prob_type);
54 }
55 
// Constructor additionally configuring cross-validation pruning: number of
// folds and whether CV-based pruning is applied after training.
// NOTE(review): doxygen line 57 (presumably the base-class initializer list)
// is missing from this extracted listing.
56 CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune)
58 {
59  init();
60  set_feature_types(attribute_types);
61  set_machine_problem_type(prob_type);
62  set_num_folds(num_folds);
63  if (cv_prune)
64  set_cv_pruning(cv_prune);
65 }
66 
// Destructor. NOTE(review): the signature (doxygen line 67) and the single
// body statement (line 69, presumably releasing a reference held by this
// object, e.g. SG_UNREF on m_alphas) were stripped by the doc extraction —
// confirm against the original CARTree.cpp.
68 {
70 }
71 
// Label setter (signature stripped by the extraction; CCARTree::set_labels
// in the original file). Accepts only multiclass or regression labels.
// NOTE(review): the branch bodies on doxygen lines 75 and 77 are missing —
// they presumably call set_machine_problem_type(PT_MULTICLASS/PT_REGRESSION);
// line 82 presumably releases the previously held m_labels. Confirm against
// the original source.
73 {
74  if (lab->get_label_type()==LT_MULTICLASS)
76  else if (lab->get_label_type()==LT_REGRESSION)
78  else
79  SG_ERROR("label type supplied is not supported\n")
80 
81  SG_REF(lab);
83  m_labels=lab;
84 }
85 
// Stores the problem type (PT_MULTICLASS / PT_REGRESSION) that training and
// prediction switch on. Signature stripped by the extraction.
87 {
88  m_mode=mode;
89 }
90 
// Returns true iff the supplied labels are compatible with the configured
// problem type. NOTE(review): doxygen line 93 — the first condition,
// presumably (m_mode==PT_MULTICLASS && lab->get_label_type()==LT_MULTICLASS)
// — is missing from this extracted listing.
92 {
94  return true;
95  else if (m_mode==PT_REGRESSION && lab->get_label_type()==LT_REGRESSION)
96  return true;
97  else
98  return false;
99 }
100 
// Classifies each vector in data by routing it from the trained tree's root
// down to a leaf (signature stripped; CCARTree::apply_multiclass(CFeatures*)
// in the original file). Returns multiclass labels, one per input vector.
102 {
103  REQUIRE(data, "Data required for classification in apply_multiclass\n")
104 
105  // apply multiclass starting from root
106  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
107 
108  REQUIRE(current, "Tree machine not yet trained.\n");
109  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
110 
111  SG_UNREF(current);
112  return dynamic_cast<CMulticlassLabels*>(ret);
113 }
114 
// Regression counterpart of apply_multiclass (signature stripped).
// NOTE(review): unlike apply_multiclass above, `current` is not REQUIRE'd
// to be non-NULL here, so applying an untrained tree would dereference NULL
// inside apply_from_current_node. The error message also says
// "apply_multiclass" — copy/paste artifact.
116 {
117  REQUIRE(data, "Data required for classification in apply_multiclass\n")
118 
119  // apply regression starting from root
120  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
121  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
122 
123  SG_UNREF(current);
124  return dynamic_cast<CRegressionLabels*>(ret);
125 }
126 
// Picks, among the candidate subtrees produced by prune_tree(this), the one
// with minimum weighted error on a held-out test set, and installs it as the
// tree's new root (signature stripped; prune_using_test_dataset in the
// original file).
128 {
// Default to uniform unit weights when the caller supplies none.
129  if (weights.vlen==0)
130  {
131  weights=SGVector<float64_t>(feats->get_num_vectors());
132  weights.fill_vector(weights.vector,weights.vlen,1);
133  }
134 
135  CDynamicObjectArray* pruned_trees=prune_tree(this);
136 
137  int32_t min_index=0;
// NOTE(review): doxygen line 138 (initialization of min_error, presumably to
// CMath::MAX_REAL_NUMBER) is missing from this extracted listing.
139  for (int32_t i=0;i<m_alphas->get_num_elements();i++)
140  {
141  CSGObject* element=pruned_trees->get_element(i);
142  bnode_t* root=NULL;
143  if (element!=NULL)
144  root=dynamic_cast<bnode_t*>(element);
145  else
146  SG_ERROR("%d element is NULL\n",i);
147 
// Evaluate this candidate subtree on the held-out data.
148  CLabels* labels=apply_from_current_node(feats, root);
149  float64_t error=compute_error(labels,gnd_truth,weights);
150  if (error<min_error)
151  {
152  min_index=i;
153  min_error=error;
154  }
155 
156  SG_UNREF(labels);
157  SG_UNREF(element);
158  }
159 
// Re-fetch the winning subtree and make it the tree's root.
160  CSGObject* element=pruned_trees->get_element(min_index);
161  bnode_t* root=NULL;
162  if (element!=NULL)
163  root=dynamic_cast<bnode_t*>(element);
164  else
165  SG_ERROR("%d element is NULL\n",min_index);
166 
167  this->set_root(root);
168 
169  SG_UNREF(pruned_trees);
170  SG_UNREF(element);
171 }
172 
// Stores per-vector training weights and marks them as explicitly set
// (signature stripped; takes SGVector<float64_t> w in the original file).
174 {
175  m_weights=w;
176  m_weights_set=true;
177 }
178 
// Returns the currently configured per-vector training weights.
180 {
181  return m_weights;
182 }
183 
185 {
187  m_weights_set=false;
188 }
189 
// Stores per-feature type flags (true = nominal/categorical, false =
// continuous) and marks them as explicitly set. Signature stripped by the
// extraction.
191 {
192  m_nominal=ft;
193  m_types_set=true;
194 }
195 
// Returns the per-feature nominal/continuous flags.
197 {
198  return m_nominal;
199 }
200 
// Clears feature-type flags so train_machine treats all features as
// continuous. NOTE(review): doxygen line 203 (presumably resetting m_nominal
// to an empty SGVector) is missing from this extracted listing.
202 {
204  m_types_set=false;
205 }
206 
207 int32_t CCARTree::get_num_folds() const
208 {
209  return m_folds;
210 }
211 
212 void CCARTree::set_num_folds(int32_t folds)
213 {
214  REQUIRE(folds>1,"Number of folds is expected to be greater than 1. Supplied value is %d\n",folds)
215  m_folds=folds;
216 }
217 
218 int32_t CCARTree::get_max_depth() const
219 {
220  return m_max_depth;
221 }
222 
223 void CCARTree::set_max_depth(int32_t depth)
224 {
225  REQUIRE(depth>0,"Max allowed tree depth should be greater than 0. Supplied value is %d\n",depth)
226  m_max_depth=depth;
227 }
228 
// Returns the minimum number of vectors a node must contain to be split
// further (signature stripped by the extraction).
230 {
231  return m_min_node_size;
232 }
233 
234 void CCARTree::set_min_node_size(int32_t nsize)
235 {
236  REQUIRE(nsize>0,"Min allowed node size should be greater than 0. Supplied value is %d\n",nsize)
237  m_min_node_size=nsize;
238 }
239 
// Sets the tolerance used to merge nearly-equal regression labels into one
// "unique" label (see get_unique_labels). Signature stripped by the
// extraction.
241 {
242  REQUIRE(ep>=0,"Input epsilon value is expected to be greater than or equal to 0\n")
243  m_label_epsilon=ep;
244 }
245 
// Trains the CART tree on dense features (signature stripped; this is
// CCARTree::train_machine(CFeatures* data) in the original file). Validates
// weights and feature types, then builds the tree and optionally applies
// cross-validation pruning.
247 {
248  REQUIRE(data,"Data required for training\n")
249  REQUIRE(data->get_feature_class()==C_DENSE,"Dense data required for training\n")
250 
251  int32_t num_features=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_features();
252  int32_t num_vectors=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_vectors();
253 
254  if (m_weights_set)
255  {
256  REQUIRE(m_weights.vlen==num_vectors,"Length of weights vector (currently %d) should be same as"
257  " number of vectors in data (presently %d)",m_weights.vlen,num_vectors)
258  }
259  else
260  {
261  // all weights are equal to 1
262  m_weights=SGVector<float64_t>(num_vectors);
// NOTE(review): doxygen line 263 (presumably filling m_weights with 1) is
// missing from this extracted listing.
264  }
265 
266  if (m_types_set)
267  {
268  REQUIRE(m_nominal.vlen==num_features,"Length of m_nominal vector (currently %d) should "
269  "be same as number of features in data (presently %d)",m_nominal.vlen,num_features)
270  }
271  else
272  {
273  SG_WARNING("Feature types are not specified. All features are considered as continuous in training")
274  m_nominal=SGVector<bool>(num_features);
// NOTE(review): doxygen line 275 (presumably filling m_nominal with false)
// is missing from this extracted listing.
276  }
277 
// NOTE(review): doxygen line 278 — presumably set_root(CARTtrain(...)) — is
// missing from this extracted listing.
279 
280  if (m_apply_cv_pruning)
281  {
282  CDenseFeatures<float64_t>* feats=dynamic_cast<CDenseFeatures<float64_t>*>(data);
// NOTE(review): doxygen line 283 — presumably the prune_by_cross_validation
// call on feats — is missing from this extracted listing.
284  }
285 
286  return true;
287 }
288 
// Installs externally pre-sorted feature columns and their index permutation,
// enabling the fast pre-sorted path in CARTtrain/compute_best_attribute.
// Signature stripped by the extraction.
290 {
291  m_pre_sort=true;
292  m_sorted_features=sorted_feats;
293  m_sorted_indices=sorted_indices;
294 }
295 
// Builds, for each feature, a sorted copy of its values plus the index
// permutation that sorts them. The output matrices are transposed relative
// to the input (one column per feature). Signature stripped by the
// extraction.
297 {
298  SGMatrix<float64_t> mat=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
299  sorted_feats = SGMatrix<float64_t>(mat.num_cols, mat.num_rows);
300  sorted_indices = SGMatrix<index_t>(mat.num_cols, mat.num_rows);
// Initialize each column's index permutation to the identity.
301  for(int32_t i=0; i<sorted_indices.num_cols; i++)
302  for(int32_t j=0; j<sorted_indices.num_rows; j++)
303  sorted_indices(j,i)=j;
304 
// Copy the feature matrix transposed via Eigen maps (no per-element loop).
305  Map<MatrixXd> map_sorted_feats(sorted_feats.matrix, mat.num_cols, mat.num_rows);
306  Map<MatrixXd> map_data(mat.matrix, mat.num_rows, mat.num_cols);
307 
308  map_sorted_feats=map_data.transpose();
309 
// Sort each feature column independently, carrying the index permutation.
310  #pragma omp parallel for
311  for(int32_t i=0; i<sorted_feats.num_cols; i++)
312  CMath::qsort_index(sorted_feats.get_column_vector(i), sorted_indices.get_column_vector(i), sorted_feats.num_rows);
313 
314 }
315 
// Recursively grows one CART node (signature stripped; CARTtrain(data,
// weights, labels, level) in the original file): computes the node label and
// resubstitution error, checks stopping rules, finds the best split, handles
// missing values via surrogate splits, and recurses on the two partitions.
317 {
318  REQUIRE(labels,"labels have to be supplied\n");
319  REQUIRE(data,"data matrix has to be supplied\n");
320 
321  bnode_t* node=new bnode_t();
322  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
323  SGMatrix<float64_t> mat=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
324  int32_t num_feats=mat.num_rows;
325  int32_t num_vecs=mat.num_cols;
326 
327  // calculate node label
328  switch(m_mode)
329  {
330  case PT_REGRESSION:
331  {
// Regression: node label is the weighted mean of the labels.
332  float64_t sum=0;
333  for (int32_t i=0;i<labels_vec.vlen;i++)
334  sum+=labels_vec[i]*weights[i];
335 
336  // lsd*total_weight=sum_of_squared_deviation
337  float64_t tot=0;
// NOTE(review): `tot` is still 0 when this product is evaluated; the
// argument-evaluation order means least_squares_deviation sets `tot` as an
// out-parameter BEFORE the multiply only if the call is evaluated first,
// which is what the original code relies on — confirm against the original.
338  node->data.weight_minus_node=tot*least_squares_deviation(labels_vec,weights,tot);
339  node->data.node_label=sum/tot;
340  node->data.total_weight=tot;
341 
342  break;
343  }
344  case PT_MULTICLASS:
345  {
// Multiclass: node label is the class with the largest total weight.
346  SGVector<float64_t> lab=labels_vec.clone();
347  CMath::qsort(lab);
348  // stores max total weight for a single label
// NOTE(review): weights are float64_t but accumulated in int32_t here —
// fractional weights are truncated; confirm whether intentional.
349  int32_t max=weights[0];
350  // stores one of the indices having max total weight
351  int32_t maxi=0;
352  int32_t c=weights[0];
353  for (int32_t i=1;i<lab.vlen;i++)
354  {
355  if (lab[i]==lab[i-1])
356  {
357  c+=weights[i];
358  }
359  else if (c>max)
360  {
361  max=c;
362  maxi=i-1;
363  c=weights[i];
364  }
365  else
366  {
367  c=weights[i];
368  }
369  }
370 
// Account for the final run of equal labels.
371  if (c>max)
372  {
373  max=c;
374  maxi=lab.vlen-1;
375  }
376 
377  node->data.node_label=lab[maxi];
378 
379  // resubstitution error calculation
380  node->data.total_weight=weights.sum(weights);
381  node->data.weight_minus_node=node->data.total_weight-max;
382  break;
383  }
384  default :
385  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
386  }
387 
388  // check stopping rules
389  // case 1 : max tree depth reached if max_depth set
390  if ((m_max_depth>0) && (level==m_max_depth))
391  {
392  node->data.num_leaves=1;
393  node->data.weight_minus_branch=node->data.weight_minus_node;
394  return node;
395  }
396 
397  // case 2 : min node size violated if min_node_size specified
398  if ((m_min_node_size>1) && (labels_vec.vlen<=m_min_node_size))
399  {
400  node->data.num_leaves=1;
401  node->data.weight_minus_branch=node->data.weight_minus_node;
402  return node;
403  }
404 
405  // choose best attribute
406  // transit_into_values for left child
407  SGVector<float64_t> left(num_feats);
408  // transit_into_values for right child
409  SGVector<float64_t> right(num_feats);
410  // final data distribution among children
411  SGVector<bool> left_final(num_vecs);
412  int32_t num_missing_final=0;
413  int32_t c_left=-1;
414  int32_t c_right=-1;
415  int32_t best_attribute;
416 
// Pre-sorted path: translate the current subset into original-matrix indices.
417  SGVector<index_t> indices(num_vecs);
418  if (m_pre_sort)
419  {
420  CSubsetStack* subset_stack = data->get_subset_stack();
421  if (subset_stack->has_subsets())
422  indices=(subset_stack->get_last_subset())->get_subset_idx();
423  else
424  indices.range_fill();
425  SG_UNREF(subset_stack);
426  best_attribute=compute_best_attribute(m_sorted_features,weights,labels,left,right,left_final,num_missing_final,c_left,c_right,0,indices);
427  }
428  else
429  best_attribute=compute_best_attribute(mat,weights,labels,left,right,left_final,num_missing_final,c_left,c_right);
430 
// No split beats MIN_SPLIT_GAIN: make this node a leaf.
431  if (best_attribute==-1)
432  {
433  node->data.num_leaves=1;
434  node->data.weight_minus_branch=node->data.weight_minus_node;
435  return node;
436  }
437 
438  SGVector<float64_t> left_transit(c_left);
439  SGVector<float64_t> right_transit(c_right);
440  sg_memcpy(left_transit.vector,left.vector,c_left*sizeof(float64_t));
441  sg_memcpy(right_transit.vector,right.vector,c_right*sizeof(float64_t));
442 
// Vectors with a missing best-attribute value are routed via surrogate splits.
443  if (num_missing_final>0)
444  {
445  SGVector<bool> is_left_final(num_vecs-num_missing_final);
446  int32_t ilf=0;
447  for (int32_t i=0;i<num_vecs;i++)
448  {
449  if (mat(best_attribute,i)!=MISSING)
450  is_left_final[ilf++]=left_final[i];
451  }
452 
453  left_final=surrogate_split(mat,weights,is_left_final,best_attribute);
454  }
455 
456  int32_t count_left=0;
457  for (int32_t c=0;c<num_vecs;c++)
458  count_left=(left_final[c])?count_left+1:count_left;
459 
// Build index subsets and weight vectors for the two children.
460  SGVector<index_t> subsetl(count_left);
461  SGVector<float64_t> weightsl(count_left);
462  SGVector<index_t> subsetr(num_vecs-count_left);
463  SGVector<float64_t> weightsr(num_vecs-count_left);
464  index_t l=0;
465  index_t r=0;
466  for (int32_t c=0;c<num_vecs;c++)
467  {
468  if (left_final[c])
469  {
470  subsetl[l]=c;
471  weightsl[l++]=weights[c];
472  }
473  else
474  {
475  subsetr[r]=c;
476  weightsr[r++]=weights[c];
477  }
478  }
479 
480  // left child
481  data->add_subset(subsetl);
482  labels->add_subset(subsetl);
483  bnode_t* left_child=CARTtrain(data,weightsl,labels,level+1);
484  data->remove_subset();
485  labels->remove_subset();
486 
487  // right child
488  data->add_subset(subsetr);
489  labels->add_subset(subsetr);
490  bnode_t* right_child=CARTtrain(data,weightsr,labels,level+1);
491  data->remove_subset();
492  labels->remove_subset();
493 
494  // set node parameters
495  node->data.attribute_id=best_attribute;
496  node->left(left_child);
497  node->right(right_child);
498  left_child->data.transit_into_values=left_transit;
499  right_child->data.transit_into_values=right_transit;
500  node->data.num_leaves=left_child->data.num_leaves+right_child->data.num_leaves;
501  node->data.weight_minus_branch=left_child->data.weight_minus_branch+right_child->data.weight_minus_branch;
502 
503  return node;
504 }
505 
// Returns the distinct label values in labels_vec and writes their count to
// n_ulabels (signature stripped by the extraction). For regression, labels
// within m_label_epsilon of each other are merged into one value; for
// classification, delta stays 0 so only exact duplicates merge.
507 {
508  float64_t delta=0;
509  if (m_mode==PT_REGRESSION)
510  delta=m_label_epsilon;
511 
512  SGVector<float64_t> ulabels(labels_vec.vlen);
513  SGVector<index_t> sidx=CMath::argsort(labels_vec);
514  ulabels[0]=labels_vec[sidx[0]];
515  n_ulabels=1;
516  int32_t start=0;
// Scan labels in sorted order; start a new unique value only when the gap
// from the current group's first element exceeds delta.
517  for (int32_t i=1;i<sidx.vlen;i++)
518  {
519  if (labels_vec[sidx[i]]<=labels_vec[sidx[start]]+delta)
520  continue;
521 
522  start=i;
523  ulabels[n_ulabels]=labels_vec[sidx[i]];
524  n_ulabels++;
525  }
526 
527  return ulabels;
528 }
529 
// Finds the best splitting attribute and split point over all (or a random
// subset of) features, maximizing impurity gain (Gini for multiclass, LSD
// for regression). Outputs: transit values for left/right children, the
// left/right assignment of each vector, counts, and the number of vectors
// whose best-attribute value is MISSING. Returns -1 when no split beats
// MIN_SPLIT_GAIN. NOTE(review): the first line of the signature (doxygen
// line 530) is missing from this extracted listing.
531  SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing_final, int32_t &count_left,
532  int32_t &count_right, int32_t subset_size, const SGVector<index_t>& active_indices)
533 {
534  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
535  int32_t num_vecs=labels->get_num_labels();
536  int32_t num_feats;
// In the pre-sorted layout features are columns; otherwise rows.
537  if (m_pre_sort)
538  num_feats=mat.num_cols;
539  else
540  num_feats=mat.num_rows;
541 
542  int32_t n_ulabels;
543  SGVector<float64_t> ulabels=get_unique_labels(labels_vec,n_ulabels);
544 
545  // if all labels same early stop
546  if (n_ulabels==1)
547  return -1;
548 
549  float64_t delta=0;
550  if (m_mode==PT_REGRESSION)
551  delta=m_label_epsilon;
552 
// total_wclasses[j] accumulates the total weight of unique label j.
553  SGVector<float64_t> total_wclasses(n_ulabels);
554  total_wclasses.zero();
555 
// simple_labels[i] = index of vector i's label within ulabels.
556  SGVector<int32_t> simple_labels(num_vecs);
557  for (int32_t i=0;i<num_vecs;i++)
558  {
559  for (int32_t j=0;j<n_ulabels;j++)
560  {
561  if (CMath::abs(labels_vec[i]-ulabels[j])<=delta)
562  {
563  simple_labels[i]=j;
564  total_wclasses[j]+=weights[i];
565  break;
566  }
567  }
568  }
569 
// Optional random feature subsampling (used e.g. by random forests):
// a non-zero subset_size restricts the scan to that many shuffled features.
570  SGVector<index_t> idx(num_feats);
571  idx.range_fill();
572  if (subset_size)
573  {
574  num_feats=subset_size;
575  CMath::permute(idx);
576  }
577 
578  float64_t max_gain=MIN_SPLIT_GAIN;
579  int32_t best_attribute=-1;
580  float64_t best_threshold=0;
581 
// For the pre-sorted path: map original row index -> position in the active
// subset, count duplicates, and remember duplicate chains in `dupes`.
582  SGVector<int64_t> indices_mask;
583  SGVector<int32_t> count_indices(mat.num_rows);
584  count_indices.zero();
585  SGVector<int32_t> dupes(num_vecs);
586  dupes.range_fill();
587  if (m_pre_sort)
588  {
589  indices_mask = SGVector<int64_t>(mat.num_rows);
590  indices_mask.set_const(-1);
591  for(int32_t j=0;j<active_indices.size();j++)
592  {
593  if (indices_mask[active_indices[j]]>=0)
594  dupes[indices_mask[active_indices[j]]]=j;
595 
596  indices_mask[active_indices[j]]=j;
597  count_indices[active_indices[j]]++;
598  }
599  }
600 
601  for (int32_t i=0;i<num_feats;i++)
602  {
// Gather this feature's values for the active vectors in sorted order.
603  SGVector<float64_t> feats(num_vecs);
604  SGVector<index_t> sorted_args(num_vecs);
605  SGVector<int32_t> temp_count_indices(count_indices.size());
606  sg_memcpy(temp_count_indices.vector, count_indices.vector, sizeof(int32_t)*count_indices.size());
607 
608  if (m_pre_sort)
609  {
610  SGVector<float64_t> temp_col(mat.get_column_vector(idx[i]), mat.num_rows, false);
611  SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(idx[i]), mat.num_rows, false);
612  int32_t count=0;
613  for(int32_t j=0;j<mat.num_rows;j++)
614  {
615  if (indices_mask[sorted_indices[j]]>=0)
616  {
617  int32_t count_index = count_indices[sorted_indices[j]];
618  while(count_index>0)
619  {
620  feats[count]=temp_col[j];
621  sorted_args[count]=indices_mask[sorted_indices[j]];
622  ++count;
623  --count_index;
624  }
625  if (count==num_vecs)
626  break;
627  }
628  }
629  }
630  else
631  {
632  for (int32_t j=0;j<num_vecs;j++)
633  feats[j]=mat(idx[i],j);
634 
635  // O(N*logN)
636  sorted_args.range_fill();
637  CMath::qsort_index(feats.vector, sorted_args.vector, feats.size());
638  }
639  int32_t n_nm_vecs=feats.vlen;
640  // number of non-missing vecs
// MISSING is MAX_REAL_NUMBER, so missing entries sort to the end.
// NOTE(review): if every value were MISSING this loop would run past the
// start of the vector — presumably that cannot happen here; confirm.
641  while (feats[n_nm_vecs-1]==MISSING)
642  {
643  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]-=weights[sorted_args[n_nm_vecs-1]];
644  n_nm_vecs--;
645  }
646 
647  // if only one unique value - it cannot be used to split
648  if (feats[n_nm_vecs-1]<=feats[0]+EQ_DELTA)
649  continue;
650 
651  if (m_nominal[idx[i]])
652  {
// Nominal attribute: enumerate all 2-way partitions of its categories.
653  SGVector<int32_t> simple_feats(num_vecs);
654  simple_feats.fill_vector(simple_feats.vector,simple_feats.vlen,-1);
655 
656  // convert to simple values
657  simple_feats[0]=0;
658  int32_t c=0;
659  for (int32_t j=1;j<n_nm_vecs;j++)
660  {
661  if (feats[j]==feats[j-1])
662  simple_feats[j]=c;
663  else
664  simple_feats[j]=(++c);
665  }
666 
667  SGVector<float64_t> ufeats(c+1);
668  ufeats[0]=feats[0];
669  int32_t u=0;
670  for (int32_t j=1;j<n_nm_vecs;j++)
671  {
672  if (feats[j]==feats[j-1])
673  continue;
674  else
675  ufeats[++u]=feats[j];
676  }
677 
678  // test all 2^(I-1)-1 possible division between two nodes
679  int32_t num_cases=CMath::pow(2,c);
680  for (int32_t k=1;k<num_cases;k++)
681  {
682  SGVector<float64_t> wleft(n_ulabels);
683  SGVector<float64_t> wright(n_ulabels);
684  wleft.zero();
685  wright.zero();
686 
687  // stores which vectors are assigned to left child
688  SGVector<bool> is_left(num_vecs);
689  is_left.fill_vector(is_left.vector,is_left.vlen,false);
690 
691  // stores which among the categorical values of chosen attribute are assigned left child
692  SGVector<bool> feats_left(c+1);
693 
694  // fill feats_left in a unique way corresponding to the case
695  for (int32_t p=0;p<c+1;p++)
696  feats_left[p]=((k/CMath::pow(2,p))%(CMath::pow(2,p+1))==1);
697 
698  // form is_left
699  for (int32_t j=0;j<n_nm_vecs;j++)
700  {
701  is_left[sorted_args[j]]=feats_left[simple_feats[j]];
702  if (is_left[sorted_args[j]])
703  wleft[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
704  else
705  wright[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
706  }
// Propagate assignments to duplicate entries (pre-sorted path).
707  for (int32_t j=n_nm_vecs-1;j>=0;j--)
708  {
709  if(dupes[j]!=j)
710  is_left[j]=is_left[dupes[j]];
711  }
712 
713  float64_t g=0;
714  if (m_mode==PT_MULTICLASS)
715  g=gain(wleft,wright,total_wclasses);
716  else if (m_mode==PT_REGRESSION)
717  g=gain(wleft,wright,total_wclasses,ulabels);
718  else
719  SG_ERROR("Undefined problem statement\n");
720 
721  if (g>max_gain)
722  {
723  best_attribute=idx[i];
724  max_gain=g;
725  sg_memcpy(is_left_final.vector,is_left.vector,is_left.vlen*sizeof(bool));
726  num_missing_final=num_vecs-n_nm_vecs;
727 
728  count_left=0;
729  for (int32_t l=0;l<c+1;l++)
730  count_left=(feats_left[l])?count_left+1:count_left;
731 
732  count_right=c+1-count_left;
733 
734  int32_t l=0;
735  int32_t r=0;
736  for (int32_t w=0;w<c+1;w++)
737  {
738  if (feats_left[w])
739  left[l++]=ufeats[w];
740  else
741  right[r++]=ufeats[w];
742  }
743  }
744  }
745  }
746  else
747  {
// Continuous attribute: sweep thresholds over sorted values, maintaining
// left/right class-weight histograms incrementally.
748  // O(N)
749  SGVector<float64_t> right_wclasses=total_wclasses.clone();
750  SGVector<float64_t> left_wclasses(n_ulabels);
751  left_wclasses.zero();
752 
753  // O(N)
754  // find best split for non-nominal attribute - choose threshold (z)
755  float64_t z=feats[0];
756  right_wclasses[simple_labels[sorted_args[0]]]-=weights[sorted_args[0]];
757  left_wclasses[simple_labels[sorted_args[0]]]+=weights[sorted_args[0]];
758  for (int32_t j=1;j<n_nm_vecs;j++)
759  {
760  if (feats[j]<=z+EQ_DELTA)
761  {
761  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
763  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
764  continue;
765  }
766  // O(F)
767  float64_t g=0;
768  if (m_mode==PT_MULTICLASS)
769  g=gain(left_wclasses,right_wclasses,total_wclasses);
770  else if (m_mode==PT_REGRESSION)
771  g=gain(left_wclasses,right_wclasses,total_wclasses,ulabels);
772  else
773  SG_ERROR("Undefined problem statement\n");
774 
775  if (g>max_gain)
776  {
777  max_gain=g;
778  best_attribute=idx[i];
779  best_threshold=z;
780  num_missing_final=num_vecs-n_nm_vecs;
781  }
782 
783  z=feats[j];
// Stop early once all remaining values are within EQ_DELTA of z.
784  if (feats[n_nm_vecs-1]<=z+EQ_DELTA)
785  break;
786  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
787  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
788  }
789  }
790 
791  // restore total_wclasses
792  while (n_nm_vecs<feats.vlen)
793  {
794  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]+=weights[sorted_args[n_nm_vecs-1]];
795  n_nm_vecs++;
796  }
797  }
798 
799  if (best_attribute==-1)
800  return -1;
801 
// For a continuous winner, materialize the vector assignment from the
// chosen threshold (the nominal case filled is_left_final inside the loop).
802  if (!m_nominal[best_attribute])
803  {
804  left[0]=best_threshold;
805  right[0]=best_threshold;
806  count_left=1;
807  count_right=1;
808  if (m_pre_sort)
809  {
810  SGVector<float64_t> temp_vec(mat.get_column_vector(best_attribute), mat.num_rows, false);
811  SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(best_attribute), mat.num_rows, false);
812  int32_t count=0;
813  for(int32_t i=0;i<mat.num_rows;i++)
814  {
815  if (indices_mask[sorted_indices[i]]>=0)
816  {
817  is_left_final[indices_mask[sorted_indices[i]]]=(temp_vec[i]<=best_threshold);
818  ++count;
819  if (count==num_vecs)
820  break;
821  }
822  }
823  for (int32_t i=num_vecs-1;i>=0;i--)
824  {
825  if(dupes[i]!=i)
826  is_left_final[i]=is_left_final[dupes[i]];
827  }
828 
829  }
830  else
831  {
832  for (int32_t i=0;i<num_vecs;i++)
833  is_left_final[i]=(mat(best_attribute,i)<=best_threshold);
834  }
835  }
836 
837  return best_attribute;
838 }
839 
// Assigns vectors whose best-attribute value is MISSING to the left or right
// child using surrogate splits (signature stripped; surrogate_split(m,
// weights, nm_left, attr) in the original file). Vectors still unassigned
// after scanning all alternative attributes follow the majority direction.
841 {
842  // return vector - left/right belongingness
843  SGVector<bool> ret(m.num_cols);
844 
845  // ditribute data with known attributes
846  int32_t l=0;
847  float64_t p_l=0.;
848  float64_t total=0.;
849  // stores indices of vectors with missing attribute
// NOTE(review): doxygen line 850 — the allocation of missing_vecs (likely
// `new CDynamicArray<int32_t>()`) — is missing from this extracted listing.
851  // stores lambda values corresponding to missing vectors - initialized all with 0
852  CDynamicArray<float64_t>* association_index=new CDynamicArray<float64_t>();
853  for (int32_t i=0;i<m.num_cols;i++)
854  {
855  if (!CMath::fequals(m(attr,i),MISSING,0))
856  {
857  ret[i]=nm_left[l];
858  total+=weights[i];
859  if (nm_left[l++])
860  p_l+=weights[i];
861  }
862  else
863  {
864  missing_vecs->push_back(i);
865  association_index->push_back(0.);
866  }
867  }
868 
869  // for lambda calculation
870  float64_t p_r=(total-p_l)/total;
871  p_l/=total;
872  float64_t p=CMath::min(p_r,p_l);
873 
874  // for each attribute (X') alternative to best split (X)
875  for (int32_t i=0;i<m.num_rows;i++)
876  {
877  if (i==attr)
878  continue;
879 
880  // find set of vectors with non-missing values for both X and X'
881  CDynamicArray<int32_t>* intersect_vecs=new CDynamicArray<int32_t>();
882  for (int32_t j=0;j<m.num_cols;j++)
883  {
884  if (!(CMath::fequals(m(i,j),MISSING,0) || CMath::fequals(m(attr,j),MISSING,0)))
885  intersect_vecs->push_back(j);
886  }
887 
888  if (intersect_vecs->get_num_elements()==0)
889  {
890  SG_UNREF(intersect_vecs);
891  continue;
892  }
893 
894 
895  if (m_nominal[i])
896  handle_missing_vecs_for_nominal_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
897  else
898  handle_missing_vecs_for_continuous_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
899 
900  SG_UNREF(intersect_vecs);
901  }
902 
903  // if some missing attribute vectors are yet not addressed, use majority rule
904  for (int32_t i=0;i<association_index->get_num_elements();i++)
905  {
906  if (association_index->get_element(i)==0.)
907  ret[missing_vecs->get_element(i)]=(p_l>=p_r);
908  }
909 
910  SG_UNREF(missing_vecs);
911  SG_UNREF(association_index);
912  return ret;
913 }
914 
// Surrogate handling for a CONTINUOUS alternative attribute: scans every
// threshold over its unique values, measures agreement (lambda) with the
// primary split, and re-routes missing-value vectors whenever this surrogate
// agrees better than any previously considered one. NOTE(review): the first
// line of the signature (doxygen line 915) is missing from this extracted
// listing — the threshold-based body indicates this is the continuous
// handler; confirm against the original file.
916  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
917  SGVector<float64_t> weights, float64_t p, int32_t attr)
918 {
919  // for lambda calculation - total weight of all vectors in X intersect X'
920  float64_t denom=0.;
921  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
922  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
923  {
924  feats[j]=m(attr,intersect_vecs->get_element(j));
925  denom+=weights[intersect_vecs->get_element(j)];
926  }
927 
928  // unique feature values for X'
929  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
930 
931 
932  // all possible splits for chosen attribute
933  for (int32_t j=0;j<num_unique-1;j++)
934  {
935  float64_t z=feats[j];
// numer: weight of vectors where the surrogate agrees with the primary
// split; numerc: weight where the complementary surrogate agrees.
936  float64_t numer=0.;
937  float64_t numerc=0.;
938  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
939  {
940  // if both go left or both go right
941  if ((m(attr,intersect_vecs->get_element(k))<=z) && is_left[intersect_vecs->get_element(k)])
942  numer+=weights[intersect_vecs->get_element(k)];
943  else if ((m(attr,intersect_vecs->get_element(k))>z) && !is_left[intersect_vecs->get_element(k)])
944  numer+=weights[intersect_vecs->get_element(k)];
945  // complementary split cases - one goes left other right
946  else if ((m(attr,intersect_vecs->get_element(k))<=z) && !is_left[intersect_vecs->get_element(k)])
947  numerc+=weights[intersect_vecs->get_element(k)];
948  else if ((m(attr,intersect_vecs->get_element(k))>z) && is_left[intersect_vecs->get_element(k)])
949  numerc+=weights[intersect_vecs->get_element(k)];
950  }
951 
// lambda measures predictive association of this surrogate with the
// primary split (higher = better agreement).
952  float64_t lambda=0.;
953  if (numer>=numerc)
954  lambda=(p-(1-numer/denom))/p;
955  else
956  lambda=(p-(1-numerc/denom))/p;
957  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
958  {
959  if ((lambda>association_index->get_element(k)) &&
960  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
961  {
962  association_index->set_element(lambda,k);
963  if (numer>=numerc)
964  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))<=z);
965  else
966  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))>z);
967  }
968  }
969  }
970 }
971 
// Surrogate handling for a NOMINAL alternative attribute: enumerates every
// 2-way partition of its categories, measures agreement (lambda) with the
// primary split, and re-routes missing-value vectors whenever this surrogate
// agrees better than any previously considered one. NOTE(review): the first
// line of the signature (doxygen line 972) is missing from this extracted
// listing — the category power-set body indicates this is the nominal
// handler; confirm against the original file.
973  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
974  SGVector<float64_t> weights, float64_t p, int32_t attr)
975 {
976  // for lambda calculation - total weight of all vectors in X intersect X'
977  float64_t denom=0.;
978  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
979  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
980  {
981  feats[j]=m(attr,intersect_vecs->get_element(j));
982  denom+=weights[intersect_vecs->get_element(j)];
983  }
984 
985  // unique feature values for X'
986  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
987 
988  // scan all splits for chosen alternative attribute X'
989  int32_t num_cases=CMath::pow(2,(num_unique-1));
990  for (int32_t j=1;j<num_cases;j++)
991  {
// Encode partition j as a left/right flag per category.
992  SGVector<bool> feats_left(num_unique);
993  for (int32_t k=0;k<num_unique;k++)
994  feats_left[k]=((j/CMath::pow(2,k))%(CMath::pow(2,k+1))==1);
995 
996  SGVector<bool> intersect_vecs_left(intersect_vecs->get_num_elements());
997  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
998  {
999  for (int32_t q=0;q<num_unique;q++)
1000  {
1001  if (feats[q]==m(attr,intersect_vecs->get_element(k)))
1002  {
1003  intersect_vecs_left[k]=feats_left[q];
1004  break;
1005  }
1006  }
1007  }
1008 
1009  float64_t numer=0.;
1010  float64_t numerc=0.;
1011  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
1012  {
1013  // if both go left or both go right
1014  if (intersect_vecs_left[k]==is_left[intersect_vecs->get_element(k)])
1015  numer+=weights[intersect_vecs->get_element(k)];
1016  else
1017  numerc+=weights[intersect_vecs->get_element(k)];
1018  }
1019 
1020  // lambda for this split (2 case identical split/complementary split)
1021  float64_t lambda=0.;
1022  if (numer>=numerc)
1023  lambda=(p-(1-numer/denom))/p;
1024  else
1025  lambda=(p-(1-numerc/denom))/p;
1026 
1027  // address missing value vectors not yet addressed or addressed using worse split
1028  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
1029  {
1030  if ((lambda>association_index->get_element(k)) &&
1031  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
1032  {
1033  association_index->set_element(lambda,k);
1034  // decide left/right based on which feature value the chosen data point has
1035  for (int32_t q=0;q<num_unique;q++)
1036  {
1037  if (feats[q]==m(attr,missing_vecs->get_element(k)))
1038  {
1039  if (numer>=numerc)
1040  is_left[missing_vecs->get_element(k)]=feats_left[q];
1041  else
1042  is_left[missing_vecs->get_element(k)]=!feats_left[q];
1043 
1044  break;
1045  }
1046  }
1047  }
1048  }
1049  }
1050 }
1051 
// Regression gain: reduction in least-squares deviation achieved by the
// split, weighted by the child/parent weight ratios. NOTE(review): the first
// line of the signature (doxygen line 1052, the wleft/wright/wtotal
// parameters) is missing from this extracted listing.
1053  SGVector<float64_t> feats)
1054 {
1055  float64_t total_lweight=0;
1056  float64_t total_rweight=0;
1057  float64_t total_weight=0;
1058 
// least_squares_deviation also returns each partition's total weight via
// its out-parameter.
1059  float64_t lsd_n=least_squares_deviation(feats,wtotal,total_weight);
1060  float64_t lsd_l=least_squares_deviation(feats,wleft,total_lweight);
1061  float64_t lsd_r=least_squares_deviation(feats,wright,total_rweight);
1062 
1063  return lsd_n-(lsd_l*(total_lweight/total_weight))-(lsd_r*(total_rweight/total_weight));
1064 }
1065 
// Multiclass gain: reduction in Gini impurity achieved by the split,
// weighted by the child/parent weight ratios (signature stripped; takes the
// wleft/wright/wtotal class-weight histograms in the original file).
1067 {
1068  float64_t total_lweight=0;
1069  float64_t total_rweight=0;
1070  float64_t total_weight=0;
1071 
// gini_impurity_index also returns each partition's total weight via its
// out-parameter.
1072  float64_t gini_n=gini_impurity_index(wtotal,total_weight);
1073  float64_t gini_l=gini_impurity_index(wleft,total_lweight);
1074  float64_t gini_r=gini_impurity_index(wright,total_rweight);
1075  return gini_n-(gini_l*(total_lweight/total_weight))-(gini_r*(total_rweight/total_weight));
1076 }
1077 
1078 float64_t CCARTree::gini_impurity_index(const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight)
1079 {
1080  Map<VectorXd> map_weighted_lab_classes(weighted_lab_classes.vector, weighted_lab_classes.size());
1081  total_weight=map_weighted_lab_classes.sum();
1082  float64_t gini=map_weighted_lab_classes.dot(map_weighted_lab_classes);
1083 
1084  gini=1.0-(gini/(total_weight*total_weight));
1085  return gini;
1086 }
1087 
// Computes the weighted variance (least-squares deviation about the weighted
// mean) of feats, and returns the total weight through the out-parameter
// (signature stripped; least_squares_deviation(feats, weights,
// total_weight) in the original file). Assumes feats and weights have the
// same length — map_feats is sized from weights.size().
1089 {
1090 
1091  Map<VectorXd> map_weights(weights.vector, weights.size());
1092  Map<VectorXd> map_feats(feats.vector, weights.size());
1093  float64_t mean=map_weights.dot(map_feats);
1094  total_weight=map_weights.sum();
1095 
1096  mean/=total_weight;
1097  float64_t dev=0;
1098  for (int32_t i=0;i<weights.vlen;i++)
1099  dev+=weights[i]*(feats[i]-mean)*(feats[i]-mean);
1100 
1101  return dev/total_weight;
1102 }
1103 
// Routes every vector in feats from `current` down to a leaf and returns the
// per-vector leaf labels wrapped in multiclass or regression labels
// according to m_mode (signature stripped by the extraction).
1105 {
1106  int32_t num_vecs=feats->get_num_vectors();
1107  REQUIRE(num_vecs>0, "No data provided in apply\n");
1108 
1109  SGVector<float64_t> labels(num_vecs);
1110  for (int32_t i=0;i<num_vecs;i++)
1111  {
1112  SGVector<float64_t> sample=feats->get_feature_vector(i);
1113  bnode_t* node=current;
1114  SG_REF(node);
1115 
1116  // until leaf is reached
1117  while(node->data.num_leaves!=1)
1118  {
1119  bnode_t* leftchild=node->left();
1120 
1121  if (m_nominal[node->data.attribute_id])
1122  {
// Nominal split: go left iff the sample's value is among the left child's
// transit values.
1123  SGVector<float64_t> comp=leftchild->data.transit_into_values;
1124  bool flag=false;
1125  for (int32_t k=0;k<comp.vlen;k++)
1126  {
1127  if (comp[k]==sample[node->data.attribute_id])
1128  {
1129  flag=true;
1130  break;
1131  }
1132  }
1133 
1134  if (flag)
1135  {
1136  SG_UNREF(node);
1137  node=leftchild;
1138  SG_REF(leftchild);
1139  }
1140  else
1141  {
// NOTE(review): node->right() is invoked on `node` after SG_UNREF(node) —
// safe only while another reference keeps the node alive; confirm.
1142  SG_UNREF(node);
1143  node=node->right();
1144  }
1145  }
1146  else
1147  {
// Continuous split: go left iff value <= stored threshold.
1148  if (sample[node->data.attribute_id]<=leftchild->data.transit_into_values[0])
1149  {
1150  SG_UNREF(node);
1151  node=leftchild;
1152  SG_REF(leftchild);
1153  }
1154  else
1155  {
1156  SG_UNREF(node);
1157  node=node->right();
1158  }
1159  }
1160 
1161  SG_UNREF(leftchild);
1162  }
1163 
1164  labels[i]=node->data.node_label;
1165  SG_UNREF(node);
1166  }
1167 
1168  switch(m_mode)
1169  {
1170  case PT_MULTICLASS:
1171  {
1172  CMulticlassLabels* mlabels=new CMulticlassLabels(labels);
1173  return mlabels;
1174  }
1175 
1176  case PT_REGRESSION:
1177  {
1178  CRegressionLabels* rlabels=new CRegressionLabels(labels);
1179  return rlabels;
1180  }
1181 
1182  default:
1183  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
1184  }
1185 
1186  return NULL;
1187 }
1188 
1190 {
1191  int32_t num_vecs=data->get_num_vectors();
1192 
1193  // divide data into V folds randomly
1194  SGVector<int32_t> subid(num_vecs);
1195  subid.random_vector(subid.vector,subid.vlen,0,folds-1);
1196 
1197  // for each fold subset
1200  SGVector<int32_t> num_alphak(folds);
1201  for (int32_t i=0;i<folds;i++)
1202  {
1203  // for chosen fold, create subset for training parameters
1204  CDynamicArray<int32_t>* test_indices=new CDynamicArray<int32_t>();
1205  CDynamicArray<int32_t>* train_indices=new CDynamicArray<int32_t>();
1206  for (int32_t j=0;j<num_vecs;j++)
1207  {
1208  if (subid[j]==i)
1209  test_indices->push_back(j);
1210  else
1211  train_indices->push_back(j);
1212  }
1213 
1214  if (test_indices->get_num_elements()==0 || train_indices->get_num_elements()==0)
1215  {
1216  SG_ERROR("Unfortunately you have reached the very low probability event where atleast one of "
1217  "the subsets in cross-validation is not represented at all. Please re-run.")
1218  }
1219 
1220  SGVector<int32_t> subset(train_indices->get_array(),train_indices->get_num_elements(),false);
1221  data->add_subset(subset);
1222  m_labels->add_subset(subset);
1223  SGVector<float64_t> subset_weights(train_indices->get_num_elements());
1224  for (int32_t j=0;j<train_indices->get_num_elements();j++)
1225  subset_weights[j]=m_weights[train_indices->get_element(j)];
1226 
1227  // train with training subset
1228  bnode_t* root=CARTtrain(data,subset_weights,m_labels,0);
1229 
1230  // prune trained tree
1232  tmax->set_root(root);
1233  CDynamicObjectArray* pruned_trees=prune_tree(tmax);
1234 
1235  data->remove_subset();
1237  subset=SGVector<int32_t>(test_indices->get_array(),test_indices->get_num_elements(),false);
1238  data->add_subset(subset);
1239  m_labels->add_subset(subset);
1240  subset_weights=SGVector<float64_t>(test_indices->get_num_elements());
1241  for (int32_t j=0;j<test_indices->get_num_elements();j++)
1242  subset_weights[j]=m_weights[test_indices->get_element(j)];
1243 
1244  // calculate R_CV values for each alpha_k using test subset and store them
1245  num_alphak[i]=m_alphas->get_num_elements();
1246  for (int32_t j=0;j<m_alphas->get_num_elements();j++)
1247  {
1248  alphak->push_back(m_alphas->get_element(j));
1249  CSGObject* jth_element=pruned_trees->get_element(j);
1250  bnode_t* current_root=NULL;
1251  if (jth_element!=NULL)
1252  current_root=dynamic_cast<bnode_t*>(jth_element);
1253  else
1254  SG_ERROR("%d element is NULL which should not be",j);
1255 
1256  CLabels* labels=apply_from_current_node(data, current_root);
1257  float64_t error=compute_error(labels, m_labels, subset_weights);
1258  r_cv->push_back(error);
1259  SG_UNREF(labels);
1260  SG_UNREF(jth_element);
1261  }
1262 
1263  data->remove_subset();
1265  SG_UNREF(train_indices);
1266  SG_UNREF(test_indices);
1267  SG_UNREF(tmax);
1268  SG_UNREF(pruned_trees);
1269  }
1270 
1271  // prune the original T_max
1272  CDynamicObjectArray* pruned_trees=prune_tree(this);
1273 
1274  // find subtree with minimum R_cv
1275  int32_t min_index=-1;
1277  for (int32_t i=0;i<m_alphas->get_num_elements();i++)
1278  {
1279  float64_t alpha=0.;
1280  if (i==m_alphas->get_num_elements()-1)
1281  alpha=m_alphas->get_element(i)+1;
1282  else
1284 
1285  float64_t rv=0.;
1286  int32_t base=0;
1287  for (int32_t j=0;j<folds;j++)
1288  {
1289  bool flag=false;
1290  for (int32_t k=base;k<num_alphak[j]+base-1;k++)
1291  {
1292  if (alphak->get_element(k)<=alpha && alphak->get_element(k+1)>alpha)
1293  {
1294  rv+=r_cv->get_element(k);
1295  flag=true;
1296  break;
1297  }
1298  }
1299 
1300  if (!flag)
1301  rv+=r_cv->get_element(num_alphak[j]+base-1);
1302 
1303  base+=num_alphak[j];
1304  }
1305 
1306  if (rv<min_r_cv)
1307  {
1308  min_index=i;
1309  min_r_cv=rv;
1310  }
1311  }
1312 
1313  CSGObject* element=pruned_trees->get_element(min_index);
1314  bnode_t* best_tree_root=NULL;
1315  if (element!=NULL)
1316  best_tree_root=dynamic_cast<bnode_t*>(element);
1317  else
1318  SG_ERROR("%d element is NULL which should not be",min_index);
1319 
1320  this->set_root(best_tree_root);
1321 
1322  SG_UNREF(element);
1323  SG_UNREF(pruned_trees);
1324  SG_UNREF(r_cv);
1325  SG_UNREF(alphak);
1326 }
1327 
1329 {
1330  REQUIRE(labels,"input labels cannot be NULL");
1331  REQUIRE(reference,"reference labels cannot be NULL")
1332 
1333  CDenseLabels* gnd_truth=dynamic_cast<CDenseLabels*>(reference);
1334  CDenseLabels* result=dynamic_cast<CDenseLabels*>(labels);
1335 
1336  float64_t denom=weights.sum(weights);
1337  float64_t numer=0.;
1338  switch (m_mode)
1339  {
1340  case PT_MULTICLASS:
1341  {
1342  for (int32_t i=0;i<weights.vlen;i++)
1343  {
1344  if (gnd_truth->get_label(i)!=result->get_label(i))
1345  numer+=weights[i];
1346  }
1347 
1348  return numer/denom;
1349  }
1350 
1351  case PT_REGRESSION:
1352  {
1353  for (int32_t i=0;i<weights.vlen;i++)
1354  numer+=weights[i]*CMath::pow((gnd_truth->get_label(i)-result->get_label(i)),2);
1355 
1356  return numer/denom;
1357  }
1358 
1359  default:
1360  SG_ERROR("Case not possible\n");
1361  }
1362 
1363  return 0.;
1364 }
1365 
1367 {
1368  REQUIRE(tree, "Tree not provided for pruning.\n");
1369 
1371  SG_UNREF(m_alphas);
1373  SG_REF(m_alphas);
1374 
1375  // base tree alpha_k=0
1376  m_alphas->push_back(0);
1378  SG_REF(t1);
1379  node_t* t1root=t1->get_root();
1380  bnode_t* t1_root=NULL;
1381  if (t1root!=NULL)
1382  t1_root=dynamic_cast<bnode_t*>(t1root);
1383  else
1384  SG_ERROR("t1_root is NULL. This is not expected\n")
1385 
1386  form_t1(t1_root);
1387  trees->push_back(t1_root);
1388  while(t1_root->data.num_leaves>1)
1389  {
1391  SG_REF(t2);
1392 
1393  node_t* t2root=t2->get_root();
1394  bnode_t* t2_root=NULL;
1395  if (t2root!=NULL)
1396  t2_root=dynamic_cast<bnode_t*>(t2root);
1397  else
1398  SG_ERROR("t1_root is NULL. This is not expected\n")
1399 
1400  float64_t a_k=find_weakest_alpha(t2_root);
1401  m_alphas->push_back(a_k);
1402  cut_weakest_link(t2_root,a_k);
1403  trees->push_back(t2_root);
1404 
1405  SG_UNREF(t1);
1406  SG_UNREF(t1_root);
1407  t1=t2;
1408  t1_root=t2_root;
1409  }
1410 
1411  SG_UNREF(t1);
1412  SG_UNREF(t1_root);
1413  return trees;
1414 }
1415 
1417 {
1418  if (node->data.num_leaves!=1)
1419  {
1420  bnode_t* left=node->left();
1421  bnode_t* right=node->right();
1422 
1423  SGVector<float64_t> weak_links(3);
1424  weak_links[0]=find_weakest_alpha(left);
1425  weak_links[1]=find_weakest_alpha(right);
1426  weak_links[2]=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
1427  weak_links[2]/=(node->data.num_leaves-1.0);
1428 
1429  SG_UNREF(left);
1430  SG_UNREF(right);
1431  return CMath::min(weak_links.vector,weak_links.vlen);
1432  }
1433 
1434  return CMath::MAX_REAL_NUMBER;
1435 }
1436 
1438 {
1439  if (node->data.num_leaves==1)
1440  return;
1441 
1442  float64_t g=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
1443  g/=(node->data.num_leaves-1.0);
1444  if (alpha==g)
1445  {
1446  node->data.num_leaves=1;
1447  node->data.weight_minus_branch=node->data.weight_minus_node;
1448  CDynamicObjectArray* children=new CDynamicObjectArray();
1449  node->set_children(children);
1450 
1451  SG_UNREF(children);
1452  }
1453  else
1454  {
1455  bnode_t* left=node->left();
1456  bnode_t* right=node->right();
1457  cut_weakest_link(left,alpha);
1458  cut_weakest_link(right,alpha);
1459  node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
1460  node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;
1461 
1462  SG_UNREF(left);
1463  SG_UNREF(right);
1464  }
1465 }
1466 
1468 {
1469  if (node->data.num_leaves!=1)
1470  {
1471  bnode_t* left=node->left();
1472  bnode_t* right=node->right();
1473 
1474  form_t1(left);
1475  form_t1(right);
1476 
1477  node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
1478  node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;
1479  if (node->data.weight_minus_node==node->data.weight_minus_branch)
1480  {
1481  node->data.num_leaves=1;
1482  CDynamicObjectArray* children=new CDynamicObjectArray();
1483  node->set_children(children);
1484 
1485  SG_UNREF(children);
1486  }
1487 
1488  SG_UNREF(left);
1489  SG_UNREF(right);
1490  }
1491 }
1492 
1494 {
1498  m_pre_sort=false;
1499  m_types_set=false;
1500  m_weights_set=false;
1501  m_apply_cv_pruning=false;
1502  m_folds=5;
1504  SG_REF(m_alphas);
1505  m_max_depth=0;
1506  m_min_node_size=0;
1507  m_label_epsilon=1e-7;
1510 
1511  SG_ADD(&m_pre_sort, "m_pre_sort", "presort", MS_NOT_AVAILABLE);
1512  SG_ADD(&m_sorted_features, "m_sorted_features", "sorted feats", MS_NOT_AVAILABLE);
1513  SG_ADD(&m_sorted_indices, "m_sorted_indices", "sorted indices", MS_NOT_AVAILABLE);
1514  SG_ADD(&m_nominal, "m_nominal", "feature types", MS_NOT_AVAILABLE);
1515  SG_ADD(&m_weights, "m_weights", "weights", MS_NOT_AVAILABLE);
1516  SG_ADD(&m_weights_set, "m_weights_set", "weights set", MS_NOT_AVAILABLE);
1517  SG_ADD(&m_types_set, "m_types_set", "feature types set", MS_NOT_AVAILABLE);
1518  SG_ADD(&m_apply_cv_pruning, "m_apply_cv_pruning", "apply cross validation pruning", MS_NOT_AVAILABLE);
1519  SG_ADD(&m_folds, "m_folds", "number of subsets for cross validation", MS_NOT_AVAILABLE);
1520  SG_ADD(&m_max_depth, "m_max_depth", "max allowed tree depth", MS_NOT_AVAILABLE)
1521  SG_ADD(&m_min_node_size, "m_min_node_size", "min allowed node size", MS_NOT_AVAILABLE)
1522  SG_ADD(&m_label_epsilon, "m_label_epsilon", "epsilon for labels", MS_NOT_AVAILABLE)
1523  SG_ADD((machine_int_t*)&m_mode, "m_mode", "problem type (multiclass or regression)", MS_NOT_AVAILABLE)
1524 }
void set_cv_pruning(bool cv_pruning)
Definition: CARTree.h:214
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:1104
bool m_types_set
Definition: CARTree.h:444
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
Definition: CARTree.cpp:530
static void permute(SGVector< T > v, CRandom *rand=NULL)
Definition: Math.h:962
bool set_element(T e, int32_t idx1, int32_t idx2=0, int32_t idx3=0)
Definition: DynamicArray.h:306
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
int32_t get_max_depth() const
Definition: CARTree.cpp:218
static void random_vector(T *vec, int32_t len, T min_value, T max_value)
Definition: SGVector.cpp:620
virtual ELabelType get_label_type() const =0
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:173
Real Labels are real-valued labels.
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:196
int32_t get_min_node_size() const
Definition: CARTree.cpp:229
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:86
int32_t get_num_folds() const
Definition: CARTree.cpp:207
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1366
int32_t index_t
Definition: common.h:72
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
SGMatrix< index_t > m_sorted_indices
Definition: CARTree.h:438
virtual int32_t get_num_labels() const =0
real valued labels (e.g. for regression, classifier outputs)
Definition: LabelTypes.h:22
static void qsort_index(T1 *output, T2 *index, uint32_t size)
Definition: Math.h:2022
static T sqrt(T x)
Definition: Math.h:428
multi-class labels 0,1,...
Definition: LabelTypes.h:20
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:101
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1416
static T sum(T *vec, int32_t len)
Return sum(vec)
Definition: SGVector.h:418
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1467
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:91
Definition: SGMatrix.h:25
CLabels * m_labels
Definition: Machine.h:436
int32_t get_num_elements() const
Definition: DynamicArray.h:200
#define SG_ERROR(...)
Definition: SGIO.h:128
#define REQUIRE(x,...)
Definition: SGIO.h:181
static const float64_t EQ_DELTA
Definition: CARTree.h:422
bool m_apply_cv_pruning
Definition: CARTree.h:450
virtual ~CCARTree()
Definition: CARTree.cpp:67
CTreeMachineNode< CARTreeNodeData > * get_root()
Definition: TreeMachine.h:88
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:840
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:246
const T & get_element(int32_t idx1, int32_t idx2=0, int32_t idx3=0) const
Definition: DynamicArray.h:212
float64_t get_label(int32_t idx)
int32_t m_max_depth
Definition: CARTree.h:462
std::enable_if<!std::is_same< T, complex128_t >::value, float64_t >::type mean(const Container< T > &a)
float64_t m_label_epsilon
Definition: CARTree.h:426
#define SG_REF(x)
Definition: SGObject.h:52
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:72
static SGVector< index_t > argsort(SGVector< T > vector)
Definition: Math.h:1418
void set_root(CTreeMachineNode< CARTreeNodeData > *root)
Definition: TreeMachine.h:78
ST * get_feature_vector(int32_t num, int32_t &len, bool &dofree)
class to add subset support to another class. A CSubsetStackStack instance should be added and wrappe...
Definition: SubsetStack.h:37
T * get_array() const
Definition: DynamicArray.h:408
Multiclass Labels for multi-class classification.
virtual void set_children(CDynamicObjectArray *children)
CSubset * get_last_subset() const
Definition: SubsetStack.h:98
static bool fequals(const T &a, const T &b, const float64_t eps, bool tolerant=false)
Definition: Math.h:308
virtual CSubsetStack * get_subset_stack()
Definition: Features.cpp:334
void clear_feature_types()
Definition: CARTree.cpp:201
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:1052
EProblemType
Definition: Machine.h:113
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:124
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
Definition: CARTree.cpp:1088
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:234
void right(CBinaryTreeMachineNode *r)
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:289
int32_t size() const
Definition: SGVector.h:156
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:915
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:212
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:1078
double float64_t
Definition: common.h:60
int32_t m_min_node_size
Definition: CARTree.h:465
int32_t m_folds
Definition: CARTree.h:453
bool m_weights_set
Definition: CARTree.h:447
virtual void remove_subset()
Definition: Labels.cpp:51
void range_fill(T start=0)
Definition: SGVector.cpp:223
index_t num_rows
Definition: SGMatrix.h:495
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:179
CTreeMachine * clone_tree()
Definition: TreeMachine.h:97
void clear_weights()
Definition: CARTree.cpp:184
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:41
virtual EFeatureClass get_feature_class() const =0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
SGVector< T > clone() const
Definition: SGVector.cpp:262
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:279
index_t num_cols
Definition: SGMatrix.h:497
SGMatrix< float64_t > m_sorted_features
Definition: CARTree.h:435
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:316
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:223
void set_const(T const_elem)
Definition: SGVector.cpp:199
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:972
structure to store data of a node of CART. This can be used as a template type in TreeMachineNode cla...
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:296
#define SG_UNREF(x)
Definition: SGObject.h:53
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
T sum(const Container< T > &a, bool no_diag=false)
virtual void remove_subset()
Definition: Features.cpp:322
int machine_int_t
Definition: common.h:69
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:190
CBinaryTreeMachineNode< CARTreeNodeData > bnode_t
Definition: TreeMachine.h:55
The class Features is the base class of all feature objects.
Definition: Features.h:69
SGVector< bool > m_nominal
Definition: CARTree.h:429
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1328
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1437
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:419
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:115
virtual bool has_subsets() const
Definition: SubsetStack.h:89
static float base
Definition: JLCoverTree.h:89
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:127
CSGObject * get_element(int32_t index) const
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
#define SG_WARNING(...)
Definition: SGIO.h:127
#define SG_ADD(...)
Definition: SGObject.h:93
SGVector< float64_t > m_weights
Definition: CARTree.h:432
Dense integer or floating point labels.
Definition: DenseLabels.h:35
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:459
T max(const Container< T > &a)
virtual int32_t get_num_vectors() const
static T min(T a, T b)
Definition: Math.h:138
static int32_t unique(T *output, int32_t size)
Definition: SGVector.cpp:841
void left(CBinaryTreeMachineNode *l)
static int32_t pow(bool x, int32_t n)
Definition: Math.h:474
T * get_column_vector(index_t col) const
Definition: SGMatrix.h:144
static const float64_t MAX_REAL_NUMBER
Definition: Math.h:1881
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
index_t vlen
Definition: SGVector.h:571
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:506
static T abs(T a)
Definition: Math.h:161
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:240
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1189
static const float64_t MISSING
Definition: CARTree.h:416
static void qsort(T *output, int32_t size)
Definition: Math.h:1134
EProblemType m_mode
Definition: CARTree.h:456

SHOGUN Machine Learning Toolbox - Documentation