SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
CARTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
33 
34 using namespace shogun;
35 
// Tolerance for treating two feature values as equal during split-point
// selection: values within EQ_DELTA of each other cannot form a split
// (see the <=...+EQ_DELTA comparisons in compute_best_attribute).
const float64_t CCARTree::EQ_DELTA=1e-7;
39 
42 {
43  init();
44 }
45 
/** CART tree constructor.
 *
 * @param attribute_types per-feature flags: true marks a nominal
 *        (categorical) feature, false a continuous one (stored via
 *        set_feature_types and later read as m_nominal)
 * @param prob_type machine problem type: PT_MULTICLASS or PT_REGRESSION
 */
CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type)
{
	init();
	set_feature_types(attribute_types);
	set_machine_problem_type(prob_type);
}
53 
/** CART tree constructor with cross-validation pruning parameters.
 *
 * @param attribute_types per-feature flags: true marks a nominal
 *        (categorical) feature, false a continuous one
 * @param prob_type machine problem type: PT_MULTICLASS or PT_REGRESSION
 * @param num_folds number of cross-validation folds (must be >1,
 *        enforced by set_num_folds)
 * @param cv_prune whether cross-validation pruning should be enabled
 */
CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune)
{
	init();
	set_feature_types(attribute_types);
	set_machine_problem_type(prob_type);
	set_num_folds(num_folds);
	if (cv_prune)
		// NOTE(review): the statement guarded by this `if` is missing from
		// this listing; presumably it enables CV pruning (m_apply_cv_pruning
		// is read during training) -- confirm against the original source.
}
64 
66 {
68 }
69 
71 {
72  if (lab->get_label_type()==LT_MULTICLASS)
74  else if (lab->get_label_type()==LT_REGRESSION)
76  else
77  SG_ERROR("label type supplied is not supported\n")
78 
79  SG_REF(lab);
81  m_labels=lab;
82 }
83 
85 {
86  m_mode=mode;
87 }
88 
90 {
92  return true;
93  else if (m_mode==PT_REGRESSION && lab->get_label_type()==LT_REGRESSION)
94  return true;
95  else
96  return false;
97 }
98 
100 {
101  REQUIRE(data, "Data required for classification in apply_multiclass\n")
102 
103  // apply multiclass starting from root
104  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
105  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
106 
107  SG_UNREF(current);
108  return dynamic_cast<CMulticlassLabels*>(ret);
109 }
110 
112 {
113  REQUIRE(data, "Data required for classification in apply_multiclass\n")
114 
115  // apply regression starting from root
116  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
117  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
118 
119  SG_UNREF(current);
120  return dynamic_cast<CRegressionLabels*>(ret);
121 }
122 
124 {
125  if (weights.vlen==0)
126  {
127  weights=SGVector<float64_t>(feats->get_num_vectors());
128  weights.fill_vector(weights.vector,weights.vlen,1);
129  }
130 
131  CDynamicObjectArray* pruned_trees=prune_tree(this);
132 
133  int32_t min_index=0;
135  for (int32_t i=0;i<m_alphas->get_num_elements();i++)
136  {
137  CSGObject* element=pruned_trees->get_element(i);
138  bnode_t* root=NULL;
139  if (element!=NULL)
140  root=dynamic_cast<bnode_t*>(element);
141  else
142  SG_ERROR("%d element is NULL\n",i);
143 
144  CLabels* labels=apply_from_current_node(feats, root);
145  float64_t error=compute_error(labels,gnd_truth,weights);
146  if (error<min_error)
147  {
148  min_index=i;
149  min_error=error;
150  }
151 
152  SG_UNREF(labels);
153  SG_UNREF(element);
154  }
155 
156  CSGObject* element=pruned_trees->get_element(min_index);
157  bnode_t* root=NULL;
158  if (element!=NULL)
159  root=dynamic_cast<bnode_t*>(element);
160  else
161  SG_ERROR("%d element is NULL\n",min_index);
162 
163  this->set_root(root);
164 
165  SG_UNREF(pruned_trees);
166  SG_UNREF(element);
167 }
168 
170 {
171  m_weights=w;
172  m_weights_set=true;
173 }
174 
176 {
177  return m_weights;
178 }
179 
181 {
183  m_weights_set=false;
184 }
185 
187 {
188  m_nominal=ft;
189  m_types_set=true;
190 }
191 
193 {
194  return m_nominal;
195 }
196 
198 {
200  m_types_set=false;
201 }
202 
203 int32_t CCARTree::get_num_folds() const
204 {
205  return m_folds;
206 }
207 
208 void CCARTree::set_num_folds(int32_t folds)
209 {
210  REQUIRE(folds>1,"Number of folds is expected to be greater than 1. Supplied value is %d\n",folds)
211  m_folds=folds;
212 }
213 
214 int32_t CCARTree::get_max_depth() const
215 {
216  return m_max_depth;
217 }
218 
219 void CCARTree::set_max_depth(int32_t depth)
220 {
221  REQUIRE(depth>0,"Max allowed tree depth should be greater than 0. Supplied value is %d\n",depth)
222  m_max_depth=depth;
223 }
224 
226 {
227  return m_min_node_size;
228 }
229 
230 void CCARTree::set_min_node_size(int32_t nsize)
231 {
232  REQUIRE(nsize>0,"Min allowed node size should be greater than 0. Supplied value is %d\n",nsize)
233  m_min_node_size=nsize;
234 }
235 
237 {
238  REQUIRE(ep>=0,"Input epsilon value is expected to be greater than or equal to 0\n")
239  m_label_epsilon=ep;
240 }
241 
243 {
244  REQUIRE(data,"Data required for training\n")
245  REQUIRE(data->get_feature_class()==C_DENSE,"Dense data required for training\n")
246 
247  int32_t num_features=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_features();
248  int32_t num_vectors=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_vectors();
249 
250  if (m_weights_set)
251  {
252  REQUIRE(m_weights.vlen==num_vectors,"Length of weights vector (currently %d) should be same as"
253  " number of vectors in data (presently %d)",m_weights.vlen,num_vectors)
254  }
255  else
256  {
257  // all weights are equal to 1
258  m_weights=SGVector<float64_t>(num_vectors);
260  }
261 
262  if (m_types_set)
263  {
264  REQUIRE(m_nominal.vlen==num_features,"Length of m_nominal vector (currently %d) should "
265  "be same as number of features in data (presently %d)",m_nominal.vlen,num_features)
266  }
267  else
268  {
269  SG_WARNING("Feature types are not specified. All features are considered as continuous in training")
270  m_nominal=SGVector<bool>(num_features);
272  }
273 
275 
276  if (m_apply_cv_pruning)
277  {
278  CDenseFeatures<float64_t>* feats=dynamic_cast<CDenseFeatures<float64_t>*>(data);
280  }
281 
282  return true;
283 }
284 
286 {
287  REQUIRE(labels,"labels have to be supplied\n");
288  REQUIRE(data,"data matrix has to be supplied\n");
289 
290  bnode_t* node=new bnode_t();
291  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
292  SGMatrix<float64_t> mat=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
293  int32_t num_feats=mat.num_rows;
294  int32_t num_vecs=mat.num_cols;
295 
296  // calculate node label
297  switch(m_mode)
298  {
299  case PT_REGRESSION:
300  {
301  float64_t sum=0;
302  for (int32_t i=0;i<labels_vec.vlen;i++)
303  sum+=labels_vec[i]*weights[i];
304 
305  // lsd*total_weight=sum_of_squared_deviation
306  float64_t tot=0;
307  node->data.weight_minus_node=tot*least_squares_deviation(labels_vec,weights,tot);
308  node->data.node_label=sum/tot;
309  node->data.total_weight=tot;
310 
311  break;
312  }
313  case PT_MULTICLASS:
314  {
315  SGVector<float64_t> lab=labels_vec.clone();
316  CMath::qsort(lab);
317  // stores max total weight for a single label
318  int32_t max=weights[0];
319  // stores one of the indices having max total weight
320  int32_t maxi=0;
321  int32_t c=weights[0];
322  for (int32_t i=1;i<lab.vlen;i++)
323  {
324  if (lab[i]==lab[i-1])
325  {
326  c+=weights[i];
327  }
328  else if (c>max)
329  {
330  max=c;
331  maxi=i-1;
332  c=weights[i];
333  }
334  else
335  {
336  c=weights[i];
337  }
338  }
339 
340  if (c>max)
341  {
342  max=c;
343  maxi=lab.vlen-1;
344  }
345 
346  node->data.node_label=lab[maxi];
347 
348  // resubstitution error calculation
349  node->data.total_weight=weights.sum(weights);
350  node->data.weight_minus_node=node->data.total_weight-max;
351  break;
352  }
353  default :
354  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
355  }
356 
357  // check stopping rules
358  // case 1 : max tree depth reached if max_depth set
359  if ((m_max_depth>0) && (level==m_max_depth))
360  {
361  node->data.num_leaves=1;
362  node->data.weight_minus_branch=node->data.weight_minus_node;
363  return node;
364  }
365 
366  // case 2 : min node size violated if min_node_size specified
367  if ((m_min_node_size>1) && (labels_vec.vlen<=m_min_node_size))
368  {
369  node->data.num_leaves=1;
370  node->data.weight_minus_branch=node->data.weight_minus_node;
371  return node;
372  }
373 
374  // choose best attribute
375  // transit_into_values for left child
376  SGVector<float64_t> left(num_feats);
377  // transit_into_values for right child
378  SGVector<float64_t> right(num_feats);
379  // final data distribution among children
380  SGVector<bool> left_final(num_vecs);
381  int32_t num_missing_final=0;
382  int32_t c_left=-1;
383  int32_t c_right=-1;
384 
385  int32_t best_attribute=compute_best_attribute(mat,weights,labels_vec,left,right,left_final,num_missing_final,c_left,c_right);
386 
387  if (best_attribute==-1)
388  {
389  node->data.num_leaves=1;
390  node->data.weight_minus_branch=node->data.weight_minus_node;
391  return node;
392  }
393 
394  SGVector<float64_t> left_transit(c_left);
395  SGVector<float64_t> right_transit(c_right);
396  memcpy(left_transit.vector,left.vector,c_left*sizeof(float64_t));
397  memcpy(right_transit.vector,right.vector,c_right*sizeof(float64_t));
398 
399  if (num_missing_final>0)
400  {
401  SGVector<bool> is_left_final(num_vecs-num_missing_final);
402  int32_t ilf=0;
403  for (int32_t i=0;i<num_vecs;i++)
404  {
405  if (mat(best_attribute,i)!=MISSING)
406  is_left_final[ilf++]=left_final[i];
407  }
408 
409  left_final=surrogate_split(mat,weights,is_left_final,best_attribute);
410  }
411 
412  int32_t count_left=0;
413  for (int32_t c=0;c<num_vecs;c++)
414  count_left=(left_final[c])?count_left+1:count_left;
415 
416  SGVector<index_t> subsetl(count_left);
417  SGVector<float64_t> weightsl(count_left);
418  SGVector<index_t> subsetr(num_vecs-count_left);
419  SGVector<float64_t> weightsr(num_vecs-count_left);
420  index_t l=0;
421  index_t r=0;
422  for (int32_t c=0;c<num_vecs;c++)
423  {
424  if (left_final[c])
425  {
426  subsetl[l]=c;
427  weightsl[l++]=weights[c];
428  }
429  else
430  {
431  subsetr[r]=c;
432  weightsr[r++]=weights[c];
433  }
434  }
435 
436  // left child
437  data->add_subset(subsetl);
438  labels->add_subset(subsetl);
439  bnode_t* left_child=CARTtrain(data,weightsl,labels,level+1);
440  data->remove_subset();
441  labels->remove_subset();
442 
443  // right child
444  data->add_subset(subsetr);
445  labels->add_subset(subsetr);
446  bnode_t* right_child=CARTtrain(data,weightsr,labels,level+1);
447  data->remove_subset();
448  labels->remove_subset();
449 
450  // set node parameters
451  node->data.attribute_id=best_attribute;
452  node->left(left_child);
453  node->right(right_child);
454  left_child->data.transit_into_values=left_transit;
455  right_child->data.transit_into_values=right_transit;
456  node->data.num_leaves=left_child->data.num_leaves+right_child->data.num_leaves;
457  node->data.weight_minus_branch=left_child->data.weight_minus_branch+right_child->data.weight_minus_branch;
458 
459  return node;
460 }
461 
463 {
464  float64_t delta=0;
465  if (m_mode==PT_REGRESSION)
466  delta=m_label_epsilon;
467 
468  SGVector<float64_t> ulabels(labels_vec.vlen);
469  SGVector<index_t> sidx=CMath::argsort(labels_vec);
470  ulabels[0]=labels_vec[sidx[0]];
471  n_ulabels=1;
472  int32_t start=0;
473  for (int32_t i=1;i<sidx.vlen;i++)
474  {
475  if (labels_vec[sidx[i]]<=labels_vec[sidx[start]]+delta)
476  continue;
477 
478  start=i;
479  ulabels[n_ulabels]=labels_vec[sidx[i]];
480  n_ulabels++;
481  }
482 
483  return ulabels;
484 }
485 
487  SGVector<float64_t> left, SGVector<float64_t> right, SGVector<bool> is_left_final, int32_t &num_missing_final, int32_t &count_left,
488  int32_t &count_right)
489 {
490  int32_t num_vecs=mat.num_cols;
491  int32_t num_feats=mat.num_rows;
492 
493  int32_t n_ulabels;
494  SGVector<float64_t> ulabels=get_unique_labels(labels_vec,n_ulabels);
495 
496  // if all labels same early stop
497  if (n_ulabels==1)
498  return -1;
499 
500  float64_t delta=0;
501  if (m_mode==PT_REGRESSION)
502  delta=m_label_epsilon;
503 
504  SGVector<float64_t> total_wclasses(n_ulabels);
505  total_wclasses.zero();
506 
507  SGVector<int32_t> simple_labels(num_vecs);
508  for (int32_t i=0;i<num_vecs;i++)
509  {
510  for (int32_t j=0;j<n_ulabels;j++)
511  {
512  if (CMath::abs(labels_vec[i]-ulabels[j])<=delta)
513  {
514  simple_labels[i]=j;
515  total_wclasses[j]+=weights[i];
516  break;
517  }
518  }
519  }
520 
521  float64_t max_gain=MIN_SPLIT_GAIN;
522  int32_t best_attribute=-1;
523  float64_t best_threshold=0;
524  for (int32_t i=0;i<num_feats;i++)
525  {
526  SGVector<float64_t> feats(num_vecs);
527  for (int32_t j=0;j<num_vecs;j++)
528  feats[j]=mat(i,j);
529 
530  // O(N*logN)
531  SGVector<index_t> sorted_args=CMath::argsort(feats);
532 
533  // number of non-missing vecs
534  int32_t n_nm_vecs=feats.vlen;
535  while (feats[sorted_args[n_nm_vecs-1]]==MISSING)
536  {
537  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]-=weights[sorted_args[n_nm_vecs-1]];
538  n_nm_vecs--;
539  }
540 
541  // if only one unique value - it cannot be used to split
542  if (feats[sorted_args[n_nm_vecs-1]]<=feats[sorted_args[0]]+EQ_DELTA)
543  continue;
544 
545  if (m_nominal[i])
546  {
547  SGVector<int32_t> simple_feats(num_vecs);
548  simple_feats.fill_vector(simple_feats.vector,simple_feats.vlen,-1);
549 
550  // convert to simple values
551  simple_feats[sorted_args[0]]=0;
552  int32_t c=0;
553  for (int32_t j=1;j<n_nm_vecs;j++)
554  {
555  if (feats[sorted_args[j]]==feats[sorted_args[j-1]])
556  simple_feats[sorted_args[j]]=c;
557  else
558  simple_feats[sorted_args[j]]=(++c);
559  }
560 
561  SGVector<float64_t> ufeats(c+1);
562  ufeats[0]=feats[sorted_args[0]];
563  int32_t u=0;
564  for (int32_t j=1;j<n_nm_vecs;j++)
565  {
566  if (feats[sorted_args[j]]==feats[sorted_args[j-1]])
567  continue;
568  else
569  ufeats[++u]=feats[sorted_args[j]];
570  }
571 
572  // test all 2^(I-1)-1 possible division between two nodes
573  int32_t num_cases=CMath::pow(2,c);
574  for (int32_t k=1;k<num_cases;k++)
575  {
576  SGVector<float64_t> wleft(n_ulabels);
577  SGVector<float64_t> wright(n_ulabels);
578  wleft.zero();
579  wright.zero();
580 
581  // stores which vectors are assigned to left child
582  SGVector<bool> is_left(num_vecs);
583  is_left.fill_vector(is_left.vector,is_left.vlen,false);
584 
585  // stores which among the categorical values of chosen attribute are assigned left child
586  SGVector<bool> feats_left(c+1);
587 
588  // fill feats_left in a unique way corresponding to the case
589  for (int32_t p=0;p<c+1;p++)
590  feats_left[p]=((k/CMath::pow(2,p))%(CMath::pow(2,p+1))==1);
591 
592  // form is_left
593  for (int32_t j=0;j<n_nm_vecs;j++)
594  {
595  is_left[sorted_args[j]]=feats_left[simple_feats[sorted_args[j]]];
596  if (is_left[sorted_args[j]])
597  wleft[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
598  else
599  wright[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
600  }
601 
602  float64_t g=0;
603  if (m_mode==PT_MULTICLASS)
604  g=gain(wleft,wright,total_wclasses);
605  else if (m_mode==PT_REGRESSION)
606  g=gain(wleft,wright,total_wclasses,ulabels);
607  else
608  SG_ERROR("Undefined problem statement\n");
609 
610  if (g>max_gain)
611  {
612  best_attribute=i;
613  max_gain=g;
614  memcpy(is_left_final.vector,is_left.vector,is_left.vlen*sizeof(bool));
615  num_missing_final=num_vecs-n_nm_vecs;
616 
617  count_left=0;
618  for (int32_t l=0;l<c+1;l++)
619  count_left=(feats_left[l])?count_left+1:count_left;
620 
621  count_right=c+1-count_left;
622 
623  int32_t l=0;
624  int32_t r=0;
625  for (int32_t w=0;w<c+1;w++)
626  {
627  if (feats_left[w])
628  left[l++]=ufeats[w];
629  else
630  right[r++]=ufeats[w];
631  }
632  }
633  }
634  }
635  else
636  {
637  // O(N)
638  SGVector<float64_t> right_wclasses=total_wclasses.clone();
639  SGVector<float64_t> left_wclasses(n_ulabels);
640  left_wclasses.zero();
641 
642  // O(N)
643  // find best split for non-nominal attribute - choose threshold (z)
644  float64_t z=feats[sorted_args[0]];
645  right_wclasses[simple_labels[sorted_args[0]]]-=weights[sorted_args[0]];
646  left_wclasses[simple_labels[sorted_args[0]]]+=weights[sorted_args[0]];
647  for (int32_t j=1;j<n_nm_vecs;j++)
648  {
649  if (feats[sorted_args[j]]<=z+EQ_DELTA)
650  {
651  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
652  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
653  continue;
654  }
655 
656  // O(F)
657  float64_t g=0;
658  if (m_mode==PT_MULTICLASS)
659  g=gain(left_wclasses,right_wclasses,total_wclasses);
660  else if (m_mode==PT_REGRESSION)
661  g=gain(left_wclasses,right_wclasses,total_wclasses,ulabels);
662  else
663  SG_ERROR("Undefined problem statement\n");
664 
665  if (g>max_gain)
666  {
667  max_gain=g;
668  best_attribute=i;
669  best_threshold=z;
670  num_missing_final=num_vecs-n_nm_vecs;
671  }
672 
673  z=feats[sorted_args[j]];
674  if (feats[sorted_args[n_nm_vecs-1]]<=z+EQ_DELTA)
675  break;
676 
677  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
678  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
679  }
680  }
681 
682  // restore total_wclasses
683  while (n_nm_vecs<feats.vlen)
684  {
685  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]+=weights[sorted_args[n_nm_vecs-1]];
686  n_nm_vecs++;
687  }
688  }
689 
690  if (best_attribute==-1)
691  return -1;
692 
693  if (!m_nominal[best_attribute])
694  {
695  left[0]=best_threshold;
696  right[0]=best_threshold;
697  count_left=1;
698  count_right=1;
699  for (int32_t i=0;i<num_vecs;i++)
700  is_left_final[i]=(mat(best_attribute,i)<=best_threshold);
701  }
702 
703  return best_attribute;
704 }
705 
707 {
708  // return vector - left/right belongingness
709  SGVector<bool> ret(m.num_cols);
710 
711  // ditribute data with known attributes
712  int32_t l=0;
713  float64_t p_l=0.;
714  float64_t total=0.;
715  // stores indices of vectors with missing attribute
717  // stores lambda values corresponding to missing vectors - initialized all with 0
718  CDynamicArray<float64_t>* association_index=new CDynamicArray<float64_t>();
719  for (int32_t i=0;i<m.num_cols;i++)
720  {
721  if (!CMath::fequals(m(attr,i),MISSING,0))
722  {
723  ret[i]=nm_left[l];
724  total+=weights[i];
725  if (nm_left[l++])
726  p_l+=weights[i];
727  }
728  else
729  {
730  missing_vecs->push_back(i);
731  association_index->push_back(0.);
732  }
733  }
734 
735  // for lambda calculation
736  float64_t p_r=(total-p_l)/total;
737  p_l/=total;
738  float64_t p=CMath::min(p_r,p_l);
739 
740  // for each attribute (X') alternative to best split (X)
741  for (int32_t i=0;i<m.num_rows;i++)
742  {
743  if (i==attr)
744  continue;
745 
746  // find set of vectors with non-missing values for both X and X'
747  CDynamicArray<int32_t>* intersect_vecs=new CDynamicArray<int32_t>();
748  for (int32_t j=0;j<m.num_cols;j++)
749  {
750  if (!(CMath::fequals(m(i,j),MISSING,0) || CMath::fequals(m(attr,j),MISSING,0)))
751  intersect_vecs->push_back(j);
752  }
753 
754  if (intersect_vecs->get_num_elements()==0)
755  {
756  SG_UNREF(intersect_vecs);
757  continue;
758  }
759 
760 
761  if (m_nominal[i])
762  handle_missing_vecs_for_nominal_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
763  else
764  handle_missing_vecs_for_continuous_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
765 
766  SG_UNREF(intersect_vecs);
767  }
768 
769  // if some missing attribute vectors are yet not addressed, use majority rule
770  for (int32_t i=0;i<association_index->get_num_elements();i++)
771  {
772  if (association_index->get_element(i)==0.)
773  ret[missing_vecs->get_element(i)]=(p_l>=p_r);
774  }
775 
776  SG_UNREF(missing_vecs);
777  SG_UNREF(association_index);
778  return ret;
779 }
780 
782  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
783  SGVector<float64_t> weights, float64_t p, int32_t attr)
784 {
785  // for lambda calculation - total weight of all vectors in X intersect X'
786  float64_t denom=0.;
787  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
788  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
789  {
790  feats[j]=m(attr,intersect_vecs->get_element(j));
791  denom+=weights[intersect_vecs->get_element(j)];
792  }
793 
794  // unique feature values for X'
795  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
796 
797 
798  // all possible splits for chosen attribute
799  for (int32_t j=0;j<num_unique-1;j++)
800  {
801  float64_t z=feats[j];
802  float64_t numer=0.;
803  float64_t numerc=0.;
804  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
805  {
806  // if both go left or both go right
807  if ((m(attr,intersect_vecs->get_element(k))<=z) && is_left[intersect_vecs->get_element(k)])
808  numer+=weights[intersect_vecs->get_element(k)];
809  else if ((m(attr,intersect_vecs->get_element(k))>z) && !is_left[intersect_vecs->get_element(k)])
810  numer+=weights[intersect_vecs->get_element(k)];
811  // complementary split cases - one goes left other right
812  else if ((m(attr,intersect_vecs->get_element(k))<=z) && !is_left[intersect_vecs->get_element(k)])
813  numerc+=weights[intersect_vecs->get_element(k)];
814  else if ((m(attr,intersect_vecs->get_element(k))>z) && is_left[intersect_vecs->get_element(k)])
815  numerc+=weights[intersect_vecs->get_element(k)];
816  }
817 
818  float64_t lambda=0.;
819  if (numer>=numerc)
820  lambda=(p-(1-numer/denom))/p;
821  else
822  lambda=(p-(1-numerc/denom))/p;
823  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
824  {
825  if ((lambda>association_index->get_element(k)) &&
826  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
827  {
828  association_index->set_element(lambda,k);
829  if (numer>=numerc)
830  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))<=z);
831  else
832  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))>z);
833  }
834  }
835  }
836 }
837 
839  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
840  SGVector<float64_t> weights, float64_t p, int32_t attr)
841 {
842  // for lambda calculation - total weight of all vectors in X intersect X'
843  float64_t denom=0.;
844  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
845  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
846  {
847  feats[j]=m(attr,intersect_vecs->get_element(j));
848  denom+=weights[intersect_vecs->get_element(j)];
849  }
850 
851  // unique feature values for X'
852  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
853 
854  // scan all splits for chosen alternative attribute X'
855  int32_t num_cases=CMath::pow(2,(num_unique-1));
856  for (int32_t j=1;j<num_cases;j++)
857  {
858  SGVector<bool> feats_left(num_unique);
859  for (int32_t k=0;k<num_unique;k++)
860  feats_left[k]=((j/CMath::pow(2,k))%(CMath::pow(2,k+1))==1);
861 
862  SGVector<bool> intersect_vecs_left(intersect_vecs->get_num_elements());
863  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
864  {
865  for (int32_t q=0;q<num_unique;q++)
866  {
867  if (feats[q]==m(attr,intersect_vecs->get_element(k)))
868  {
869  intersect_vecs_left[k]=feats_left[q];
870  break;
871  }
872  }
873  }
874 
875  float64_t numer=0.;
876  float64_t numerc=0.;
877  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
878  {
879  // if both go left or both go right
880  if (intersect_vecs_left[k]==is_left[intersect_vecs->get_element(k)])
881  numer+=weights[intersect_vecs->get_element(k)];
882  else
883  numerc+=weights[intersect_vecs->get_element(k)];
884  }
885 
886  // lambda for this split (2 case identical split/complementary split)
887  float64_t lambda=0.;
888  if (numer>=numerc)
889  lambda=(p-(1-numer/denom))/p;
890  else
891  lambda=(p-(1-numerc/denom))/p;
892 
893  // address missing value vectors not yet addressed or addressed using worse split
894  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
895  {
896  if ((lambda>association_index->get_element(k)) &&
897  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
898  {
899  association_index->set_element(lambda,k);
900  // decide left/right based on which feature value the chosen data point has
901  for (int32_t q=0;q<num_unique;q++)
902  {
903  if (feats[q]==m(attr,missing_vecs->get_element(k)))
904  {
905  if (numer>=numerc)
906  is_left[missing_vecs->get_element(k)]=feats_left[q];
907  else
908  is_left[missing_vecs->get_element(k)]=!feats_left[q];
909 
910  break;
911  }
912  }
913  }
914  }
915  }
916 }
917 
919  SGVector<float64_t> feats)
920 {
921  float64_t total_lweight=0;
922  float64_t total_rweight=0;
923  float64_t total_weight=0;
924 
925  float64_t lsd_n=least_squares_deviation(feats,wtotal,total_weight);
926  float64_t lsd_l=least_squares_deviation(feats,wleft,total_lweight);
927  float64_t lsd_r=least_squares_deviation(feats,wright,total_rweight);
928 
929  return lsd_n-(lsd_l*(total_lweight/total_weight))-(lsd_r*(total_rweight/total_weight));
930 }
931 
933 {
934  float64_t total_lweight=0;
935  float64_t total_rweight=0;
936  float64_t total_weight=0;
937 
938  float64_t gini_n=gini_impurity_index(wtotal,total_weight);
939  float64_t gini_l=gini_impurity_index(wleft,total_lweight);
940  float64_t gini_r=gini_impurity_index(wright,total_rweight);
941  return gini_n-(gini_l*(total_lweight/total_weight))-(gini_r*(total_rweight/total_weight));
942 }
943 
945 {
946  total_weight=0;
947  float64_t gini=0;
948  for (int32_t i=0;i<weighted_lab_classes.vlen;i++)
949  {
950  total_weight+=weighted_lab_classes[i];
951  gini+=weighted_lab_classes[i]*weighted_lab_classes[i];
952  }
953 
954  gini=1.0-(gini/(total_weight*total_weight));
955  return gini;
956 }
957 
959 {
960  float64_t mean=0;
961  total_weight=0;
962  for (int32_t i=0;i<weights.vlen;i++)
963  {
964  mean+=feats[i]*weights[i];
965  total_weight+=weights[i];
966  }
967 
968  mean/=total_weight;
969  float64_t dev=0;
970  for (int32_t i=0;i<weights.vlen;i++)
971  dev+=weights[i]*(feats[i]-mean)*(feats[i]-mean);
972 
973  return dev/total_weight;
974 }
975 
977 {
978  int32_t num_vecs=feats->get_num_vectors();
979  SGVector<float64_t> labels(num_vecs);
980  for (int32_t i=0;i<num_vecs;i++)
981  {
982  SGVector<float64_t> sample=feats->get_feature_vector(i);
983  bnode_t* node=current;
984  SG_REF(node);
985 
986  // until leaf is reached
987  while(node->data.num_leaves!=1)
988  {
989  bnode_t* leftchild=node->left();
990 
991  if (m_nominal[node->data.attribute_id])
992  {
993  SGVector<float64_t> comp=leftchild->data.transit_into_values;
994  bool flag=false;
995  for (int32_t k=0;k<comp.vlen;k++)
996  {
997  if (comp[k]==sample[node->data.attribute_id])
998  {
999  flag=true;
1000  break;
1001  }
1002  }
1003 
1004  if (flag)
1005  {
1006  SG_UNREF(node);
1007  node=leftchild;
1008  SG_REF(leftchild);
1009  }
1010  else
1011  {
1012  SG_UNREF(node);
1013  node=node->right();
1014  }
1015  }
1016  else
1017  {
1018  if (sample[node->data.attribute_id]<=leftchild->data.transit_into_values[0])
1019  {
1020  SG_UNREF(node);
1021  node=leftchild;
1022  SG_REF(leftchild);
1023  }
1024  else
1025  {
1026  SG_UNREF(node);
1027  node=node->right();
1028  }
1029  }
1030 
1031  SG_UNREF(leftchild);
1032  }
1033 
1034  labels[i]=node->data.node_label;
1035  SG_UNREF(node);
1036  }
1037 
1038  switch(m_mode)
1039  {
1040  case PT_MULTICLASS:
1041  {
1042  CMulticlassLabels* mlabels=new CMulticlassLabels(labels);
1043  return mlabels;
1044  }
1045 
1046  case PT_REGRESSION:
1047  {
1048  CRegressionLabels* rlabels=new CRegressionLabels(labels);
1049  return rlabels;
1050  }
1051 
1052  default:
1053  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
1054  }
1055 
1056  return NULL;
1057 }
1058 
1060 {
1061  int32_t num_vecs=data->get_num_vectors();
1062 
1063  // divide data into V folds randomly
1064  SGVector<int32_t> subid(num_vecs);
1065  subid.random_vector(subid.vector,subid.vlen,0,folds-1);
1066 
1067  // for each fold subset
1070  SGVector<int32_t> num_alphak(folds);
1071  for (int32_t i=0;i<folds;i++)
1072  {
1073  // for chosen fold, create subset for training parameters
1074  CDynamicArray<int32_t>* test_indices=new CDynamicArray<int32_t>();
1075  CDynamicArray<int32_t>* train_indices=new CDynamicArray<int32_t>();
1076  for (int32_t j=0;j<num_vecs;j++)
1077  {
1078  if (subid[j]==i)
1079  test_indices->push_back(j);
1080  else
1081  train_indices->push_back(j);
1082  }
1083 
1084  if (test_indices->get_num_elements()==0 || train_indices->get_num_elements()==0)
1085  {
1086  SG_ERROR("Unfortunately you have reached the very low probability event where atleast one of "
1087  "the subsets in cross-validation is not represented at all. Please re-run.")
1088  }
1089 
1090  SGVector<int32_t> subset(train_indices->get_array(),train_indices->get_num_elements(),false);
1091  data->add_subset(subset);
1092  m_labels->add_subset(subset);
1093  SGVector<float64_t> subset_weights(train_indices->get_num_elements());
1094  for (int32_t j=0;j<train_indices->get_num_elements();j++)
1095  subset_weights[j]=m_weights[train_indices->get_element(j)];
1096 
1097  // train with training subset
1098  bnode_t* root=CARTtrain(data,subset_weights,m_labels,0);
1099 
1100  // prune trained tree
1102  tmax->set_root(root);
1103  CDynamicObjectArray* pruned_trees=prune_tree(tmax);
1104 
1105  data->remove_subset();
1107  subset=SGVector<int32_t>(test_indices->get_array(),test_indices->get_num_elements(),false);
1108  data->add_subset(subset);
1109  m_labels->add_subset(subset);
1110  subset_weights=SGVector<float64_t>(test_indices->get_num_elements());
1111  for (int32_t j=0;j<test_indices->get_num_elements();j++)
1112  subset_weights[j]=m_weights[test_indices->get_element(j)];
1113 
1114  // calculate R_CV values for each alpha_k using test subset and store them
1115  num_alphak[i]=m_alphas->get_num_elements();
1116  for (int32_t j=0;j<m_alphas->get_num_elements();j++)
1117  {
1118  alphak->push_back(m_alphas->get_element(j));
1119  CSGObject* jth_element=pruned_trees->get_element(j);
1120  bnode_t* current_root=NULL;
1121  if (jth_element!=NULL)
1122  current_root=dynamic_cast<bnode_t*>(jth_element);
1123  else
1124  SG_ERROR("%d element is NULL which should not be",j);
1125 
1126  CLabels* labels=apply_from_current_node(data, current_root);
1127  float64_t error=compute_error(labels, m_labels, subset_weights);
1128  r_cv->push_back(error);
1129  SG_UNREF(labels);
1130  SG_UNREF(jth_element);
1131  }
1132 
1133  data->remove_subset();
1135  SG_UNREF(train_indices);
1136  SG_UNREF(test_indices);
1137  SG_UNREF(tmax);
1138  SG_UNREF(pruned_trees);
1139  }
1140 
1141  // prune the original T_max
1142  CDynamicObjectArray* pruned_trees=prune_tree(this);
1143 
1144  // find subtree with minimum R_cv
1145  int32_t min_index=-1;
1147  for (int32_t i=0;i<m_alphas->get_num_elements();i++)
1148  {
1149  float64_t alpha=0.;
1150  if (i==m_alphas->get_num_elements()-1)
1151  alpha=m_alphas->get_element(i)+1;
1152  else
1154 
1155  float64_t rv=0.;
1156  int32_t base=0;
1157  for (int32_t j=0;j<folds;j++)
1158  {
1159  bool flag=false;
1160  for (int32_t k=base;k<num_alphak[j]+base-1;k++)
1161  {
1162  if (alphak->get_element(k)<=alpha && alphak->get_element(k+1)>alpha)
1163  {
1164  rv+=r_cv->get_element(k);
1165  flag=true;
1166  break;
1167  }
1168  }
1169 
1170  if (!flag)
1171  rv+=r_cv->get_element(num_alphak[j]+base-1);
1172 
1173  base+=num_alphak[j];
1174  }
1175 
1176  if (rv<min_r_cv)
1177  {
1178  min_index=i;
1179  min_r_cv=rv;
1180  }
1181  }
1182 
1183  CSGObject* element=pruned_trees->get_element(min_index);
1184  bnode_t* best_tree_root=NULL;
1185  if (element!=NULL)
1186  best_tree_root=dynamic_cast<bnode_t*>(element);
1187  else
1188  SG_ERROR("%d element is NULL which should not be",min_index);
1189 
1190  this->set_root(best_tree_root);
1191 
1192  SG_UNREF(element);
1193  SG_UNREF(pruned_trees);
1194  SG_UNREF(r_cv);
1195  SG_UNREF(alphak);
1196 }
1197 
// Weighted prediction error of `labels` against `reference`.
// Per the cross-reference index, the signature line (CARTree.cpp:1198) that the
// doc extraction dropped is:
//   float64_t CCARTree::compute_error(CLabels* labels, CLabels* reference,
//                                     SGVector<float64_t> weights)
// PT_MULTICLASS -> weighted misclassification rate; PT_REGRESSION -> weighted
// mean squared error. Both are normalized by the total weight.
1199 {
1200  REQUIRE(labels,"input labels cannot be NULL");
1201  REQUIRE(reference,"reference labels cannot be NULL")
1202 
// Downcast to dense labels so per-index get_label(i) is available.
// NOTE(review): the results of these dynamic_casts are used without a NULL
// check below — presumably callers always pass dense labels; verify upstream.
1203  CDenseLabels* gnd_truth=dynamic_cast<CDenseLabels*>(reference);
1204  CDenseLabels* result=dynamic_cast<CDenseLabels*>(labels);
1205 
// denom = sum of all sample weights (normalizer); numer accumulates the
// weighted error mass for the active problem type.
1206  float64_t denom=weights.sum(weights);
1207  float64_t numer=0.;
1208  switch (m_mode)
1209  {
1210  case PT_MULTICLASS:
1211  {
// Count weight of samples whose predicted class differs from ground truth.
1212  for (int32_t i=0;i<weights.vlen;i++)
1213  {
1214  if (gnd_truth->get_label(i)!=result->get_label(i))
1215  numer+=weights[i];
1216  }
1217 
1218  return numer/denom;
1219  }
1220 
1221  case PT_REGRESSION:
1222  {
// Weighted sum of squared residuals, normalized to weighted MSE.
1223  for (int32_t i=0;i<weights.vlen;i++)
1224  numer+=weights[i]*CMath::pow((gnd_truth->get_label(i)-result->get_label(i)),2);
1225 
1226  return numer/denom;
1227  }
1228 
1229  default:
1230  SG_ERROR("Case not possible\n");
1231  }
1232 
// Unreachable in valid modes; keeps the compiler satisfied after SG_ERROR.
1233  return 0.;
1234 }
1235 
// Minimal cost-complexity (weakest-link) pruning: produces the nested sequence
// of pruned subtrees T1 > T2 > ... > {root}, pushing the critical alpha_k for
// each subtree into m_alphas and returning the subtree array.
// Per the index, the stripped signature line (CARTree.cpp:1236) is:
//   CDynamicObjectArray* CCARTree::prune_tree(CTreeMachine<CARTreeNodeData>* tree)
// NOTE(review): this doc rendering elided several hyperlinked lines
// (1238, 1240, 1245, 1258) — presumably the allocations of `trees`, the fresh
// `m_alphas` array, and the `t1`/`t2` clones via clone_tree(); confirm against
// the actual CARTree.cpp before relying on this listing.
1237 {
// Replace any previous alpha sequence with a fresh (ref-counted) array.
1239  SG_UNREF(m_alphas);
1241  SG_REF(m_alphas);
1242 
1243  // base tree alpha_k=0
1244  m_alphas->push_back(0);
// [line 1245 elided: presumably t1 = tree->clone_tree() — verify]
1246  SG_REF(t1);
1247  node_t* t1root=t1->get_root();
1248  bnode_t* t1_root=NULL;
1249  if (t1root!=NULL)
1250  t1_root=dynamic_cast<bnode_t*>(t1root);
1251  else
1252  SG_ERROR("t1_root is NULL. This is not expected\n")
1253 
// T1: collapse every internal node whose split gains nothing (see form_t1),
// then record it as the first tree of the sequence.
1254  form_t1(t1_root);
1255  trees->push_back(t1_root);
// Repeatedly cut the weakest link until only the root leaf remains.
1256  while(t1_root->data.num_leaves>1)
1257  {
// [line 1258 elided: presumably t2 = t1->clone_tree() — verify]
1259  SG_REF(t2);
1260 
1261  node_t* t2root=t2->get_root();
1262  bnode_t* t2_root=NULL;
1263  if (t2root!=NULL)
1264  t2_root=dynamic_cast<bnode_t*>(t2root);
1265  else
1266  SG_ERROR("t1_root is NULL. This is not expected\n")
1267 
// alpha_k = smallest g(t) over internal nodes; prune all nodes achieving it.
1268  float64_t a_k=find_weakest_alpha(t2_root);
1269  m_alphas->push_back(a_k);
1270  cut_weakest_link(t2_root,a_k);
1271  trees->push_back(t2_root);
1272 
// Release the previous iteration's tree/root and advance.
1273  SG_UNREF(t1);
1274  SG_UNREF(t1_root);
1275  t1=t2;
1276  t1_root=t2_root;
1277  }
1278 
1279  SG_UNREF(t1);
1280  SG_UNREF(t1_root);
1281  return trees;
1282 }
1283 
// Returns the smallest link strength g(t) in the subtree rooted at `node`,
// where g(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1) in weighted-error terms —
// the critical alpha at which pruning node t becomes worthwhile.
// Per the index, the stripped signature line (CARTree.cpp:1284) is:
//   float64_t CCARTree::find_weakest_alpha(bnode_t* node)
1285 {
// Internal node: weakest alpha is the min over both children and this node.
1286  if (node->data.num_leaves!=1)
1287  {
1288  bnode_t* left=node->left();
1289  bnode_t* right=node->right();
1290 
1291  SGVector<float64_t> weak_links(3);
1292  weak_links[0]=find_weakest_alpha(left);
1293  weak_links[1]=find_weakest_alpha(right);
// g(t) for this node: error increase from collapsing to a leaf, normalized by
// total weight and by the number of leaves removed (num_leaves - 1).
1294  weak_links[2]=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
1295  weak_links[2]/=(node->data.num_leaves-1.0);
1296 
// left()/right() return ref-counted pointers, so release them here.
1297  SG_UNREF(left);
1298  SG_UNREF(right);
1299  return CMath::min(weak_links.vector,weak_links.vlen);
1300  }
1301 
// Leaf: no link to cut; MAX_REAL_NUMBER so it never wins the min above.
1302  return CMath::MAX_REAL_NUMBER;
1303 }
1304 
// Prunes every internal node whose link strength g(t) equals `alpha`, turning
// it into a leaf, and refreshes num_leaves / weight_minus_branch on the path.
// Per the index, the stripped signature line (CARTree.cpp:1305) is:
//   void CCARTree::cut_weakest_link(bnode_t* node, float64_t alpha)
1306 {
// Leaves have nothing to cut.
1307  if (node->data.num_leaves==1)
1308  return;
1309 
// Same g(t) formula as in find_weakest_alpha.
1310  float64_t g=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
1311  g/=(node->data.num_leaves-1.0);
// NOTE(review): exact float equality — works here only because `alpha` was
// produced by the identical expression in find_weakest_alpha; a tolerance
// (cf. EQ_DELTA declared at the top of this file) would be more robust.
1312  if (alpha==g)
1313  {
// Collapse this subtree into a leaf: branch error becomes node error and the
// children are replaced by an empty array.
1314  node->data.num_leaves=1;
1315  node->data.weight_minus_branch=node->data.weight_minus_node;
1316  CDynamicObjectArray* children=new CDynamicObjectArray();
1317  node->set_children(children);
1318 
1319  SG_UNREF(children);
1320  }
1321  else
1322  {
// Recurse, then recompute this node's aggregates from the (possibly pruned)
// children.
1323  bnode_t* left=node->left();
1324  bnode_t* right=node->right();
1325  cut_weakest_link(left,alpha);
1326  cut_weakest_link(right,alpha);
1327  node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
1328  node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;
1329 
1330  SG_UNREF(left);
1331  SG_UNREF(right);
1332  }
1333 }
1334 
// Builds T1 from T_max: bottom-up, collapses any internal node whose split
// does not reduce the weighted error (R(t) == R(T_t)) into a leaf, and
// recomputes num_leaves / weight_minus_branch along the way.
// Per the index, the stripped signature line (CARTree.cpp:1335) is:
//   void CCARTree::form_t1(bnode_t* node)
1336 {
1337  if (node->data.num_leaves!=1)
1338  {
1339  bnode_t* left=node->left();
1340  bnode_t* right=node->right();
1341 
// Process children first so their aggregates are final before we sum them.
1342  form_t1(left);
1343  form_t1(right);
1344 
1345  node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
1346  node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;
// Zero-gain split: collapsing costs nothing, so make this node a leaf.
1347  if (node->data.weight_minus_node==node->data.weight_minus_branch)
1348  {
1349  node->data.num_leaves=1;
1350  CDynamicObjectArray* children=new CDynamicObjectArray();
1351  node->set_children(children);
1352 
1353  SG_UNREF(children);
1354  }
1355 
1356  SG_UNREF(left);
1357  SG_UNREF(right);
1358  }
1359 }
1360 
// One-time member initialization shared by the constructors, plus SG_ADD
// parameter registration for serialization/model-selection.
// NOTE(review): this doc rendering elided hyperlinked lines 1363-1365 and 1370
// — presumably the initialization of m_nominal / m_weights / m_mode and the
// `m_alphas = new CDynamicArray<float64_t>()` allocation that the SG_REF on
// line 1371 pairs with; confirm against the actual CARTree.cpp.
1362 {
// Defaults: no feature types or weights supplied yet, CV pruning off,
// 5 folds, unlimited depth (0), no minimum node size, tiny label epsilon.
1366  m_types_set=false;
1367  m_weights_set=false;
1368  m_apply_cv_pruning=false;
1369  m_folds=5;
1371  SG_REF(m_alphas);
1372  m_max_depth=0;
1373  m_min_node_size=0;
1374  m_label_epsilon=1e-7;
1375 
// Register members so they are serialized; none participate in model selection.
1376  SG_ADD(&m_nominal,"m_nominal", "feature types", MS_NOT_AVAILABLE);
1377  SG_ADD(&m_weights,"m_weights", "weights", MS_NOT_AVAILABLE);
1378  SG_ADD(&m_weights_set,"m_weights_set", "weights set", MS_NOT_AVAILABLE);
1379  SG_ADD(&m_types_set,"m_types_set", "feature types set", MS_NOT_AVAILABLE);
1380  SG_ADD(&m_apply_cv_pruning,"m_apply_cv_pruning", "apply cross validation pruning", MS_NOT_AVAILABLE);
1381  SG_ADD(&m_folds,"m_folds","number of subsets for cross validation", MS_NOT_AVAILABLE);
1382  SG_ADD(&m_max_depth,"m_max_depth","max allowed tree depth",MS_NOT_AVAILABLE)
1383  SG_ADD(&m_min_node_size,"m_min_node_size","min allowed node size",MS_NOT_AVAILABLE)
1384  SG_ADD(&m_label_epsilon,"m_label_epsilon","epsilon for labels",MS_NOT_AVAILABLE)
1385 }
void set_cv_pruning()
Definition: CARTree.h:211
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:976
bool m_types_set
Definition: CARTree.h:428
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
int32_t get_max_depth() const
Definition: CARTree.cpp:214
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:223
virtual ELabelType get_label_type() const =0
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:169
Real Labels are real-valued labels.
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:192
ST * get_feature_vector(int32_t num, int32_t &len, bool &dofree)
int32_t get_min_node_size() const
Definition: CARTree.cpp:225
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:84
int32_t get_num_folds() const
Definition: CARTree.cpp:203
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1236
int32_t index_t
Definition: common.h:62
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
real valued labels (e.g. for regression, classifier outputs)
Definition: LabelTypes.h:22
multi-class labels 0,1,...
Definition: LabelTypes.h:20
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:99
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1284
float64_t least_squares_deviation(SGVector< float64_t > labels, SGVector< float64_t > weights, float64_t &total_weight)
Definition: CARTree.cpp:958
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1335
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:89
T * get_array() const
Definition: DynamicArray.h:408
CLabels * m_labels
Definition: Machine.h:361
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
index_t num_cols
Definition: SGMatrix.h:378
static const float64_t EQ_DELTA
Definition: CARTree.h:415
bool m_apply_cv_pruning
Definition: CARTree.h:434
virtual ~CCARTree()
Definition: CARTree.cpp:65
CTreeMachineNode< CARTreeNodeData > * get_root()
Definition: TreeMachine.h:88
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:706
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:242
float64_t get_label(int32_t idx)
int32_t m_max_depth
Definition: CARTree.h:446
float64_t m_label_epsilon
Definition: CARTree.h:419
#define SG_REF(x)
Definition: SGObject.h:51
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:70
index_t num_rows
Definition: SGMatrix.h:376
void set_root(CTreeMachineNode< CARTreeNodeData > *root)
Definition: TreeMachine.h:78
static void qsort(T *output, int32_t size)
Definition: Math.h:1313
Multiclass Labels for multi-class classification.
static bool fequals(const T &a, const T &b, const float64_t eps, bool tolerant=false)
Definition: Math.h:335
virtual void set_children(CDynamicObjectArray *children)
index_t vlen
Definition: SGVector.h:494
float64_t gini_impurity_index(SGVector< float64_t > weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:944
void clear_feature_types()
Definition: CARTree.cpp:197
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:918
EProblemType
Definition: Machine.h:110
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
bool set_element(T e, int32_t idx1, int32_t idx2=0, int32_t idx3=0)
Definition: DynamicArray.h:306
virtual int32_t get_num_vectors() const
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:230
void right(CBinaryTreeMachineNode *r)
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:781
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:208
double float64_t
Definition: common.h:50
int32_t m_min_node_size
Definition: CARTree.h:449
int32_t m_folds
Definition: CARTree.h:437
static SGVector< index_t > argsort(SGVector< T > vector)
Definition: Math.h:1599
bool m_weights_set
Definition: CARTree.h:431
virtual void remove_subset()
Definition: Labels.cpp:49
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:175
CTreeMachine * clone_tree()
Definition: TreeMachine.h:97
void clear_weights()
Definition: CARTree.cpp:180
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
static T sum(T *vec, int32_t len)
Return sum(vec)
Definition: SGVector.h:354
virtual EFeatureClass get_feature_class() const =0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:285
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:219
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:838
structure to store data of a node of CART. This can be used as a template type in TreeMachineNode cla...
#define SG_UNREF(x)
Definition: SGObject.h:52
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
virtual void remove_subset()
Definition: Features.cpp:322
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:186
CBinaryTreeMachineNode< CARTreeNodeData > bnode_t
Definition: TreeMachine.h:55
virtual int32_t compute_best_attribute(SGMatrix< float64_t > mat, SGVector< float64_t > weights, SGVector< float64_t > labels_vec, SGVector< float64_t > left, SGVector< float64_t > right, SGVector< bool > is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right)
Definition: CARTree.cpp:486
The class Features is the base class of all feature objects.
Definition: Features.h:68
static T min(T a, T b)
Definition: Math.h:157
SGVector< bool > m_nominal
Definition: CARTree.h:422
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1198
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1305
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:412
SGVector< T > clone() const
Definition: SGVector.cpp:209
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:111
static float base
Definition: JLCoverTree.h:84
int32_t get_num_elements() const
Definition: DynamicArray.h:200
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:123
CSGObject * get_element(int32_t index) const
Matrix::Scalar max(Matrix m)
Definition: Redux.h:66
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
#define SG_WARNING(...)
Definition: SGIO.h:128
#define SG_ADD(...)
Definition: SGObject.h:81
SGVector< float64_t > m_weights
Definition: CARTree.h:425
static float32_t sqrt(float32_t x)
Definition: Math.h:459
Dense integer or floating point labels.
Definition: DenseLabels.h:35
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:443
#define delta
Definition: sfa.cpp:23
void left(CBinaryTreeMachineNode *l)
static int32_t unique(T *output, int32_t size)
Definition: SGVector.cpp:787
static int32_t pow(bool x, int32_t n)
Definition: Math.h:535
static const float64_t MAX_REAL_NUMBER
Definition: Math.h:2061
const T & get_element(int32_t idx1, int32_t idx2=0, int32_t idx3=0) const
Definition: DynamicArray.h:212
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:462
static T abs(T a)
Definition: Math.h:179
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:236
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1059
static void random_vector(T *vec, int32_t len, T min_value, T max_value)
Definition: SGVector.cpp:565
static const float64_t MISSING
Definition: CARTree.h:409
EProblemType m_mode
Definition: CARTree.h:440

SHOGUN Machine Learning Toolbox - Documentation