49 m_dependent_vartype=dependent_vartype;
56 m_dependent_vartype=dependent_vartype;
58 m_num_breakpoints=num_breakpoints;
67 switch (m_dependent_vartype)
76 SG_ERROR(
"Invalid dependent variable type set (%d). Problem type undefined\n",m_dependent_vartype);
84 switch (m_dependent_vartype)
93 SG_ERROR(
"Invalid dependent variable type set (%d). Problem type undefined\n",m_dependent_vartype);
101 REQUIRE(data,
"Data required for classification in apply_multiclass\n")
108 REQUIRE(data,
"Data required for regression in apply_regression\n")
140 return m_feature_types;
150 REQUIRE(((var==0)||(var==1)||(var==2)),
"Expected 0 or 1 or 2 as argument. %d received\n",var)
151 m_dependent_vartype=var;
156 REQUIRE(data,
"Data required for training\n")
161 REQUIRE(m_feature_types.
vlen==feats->get_num_features(),
"Either feature types are not set or the number of feature types specified"
162 " (%d here) is not the same as the number of features in data matrix (%d here)\n",m_feature_types.
vlen,feats->get_num_features())
166 REQUIRE(m_weights.
vlen==feats->get_num_vectors(),
"Length of weights vector (currently %d) should be same as"
167 " the number of vectors in data (presently %d)",m_weights.
vlen,feats->get_num_vectors())
177 bool updated=continuous_to_ordinal(feats);
183 feature_types_cache=m_feature_types.
clone();
184 for (int32_t i=0;i<m_feature_types.
vlen;i++)
186 if (m_feature_types[i]==2)
187 m_feature_types[i]=1;
195 m_feature_types=feature_types_cache;
202 REQUIRE(data,
"data matrix cannot be empty\n");
203 REQUIRE(labels,
"labels cannot be NULL\n");
212 if (m_dependent_vartype==2)
218 else if (m_dependent_vartype==0 || m_dependent_vartype==1)
223 int32_t
max=weights[0];
226 int32_t c=weights[0];
227 for (int32_t i=1;i<lab.
vlen;i++)
229 if (lab[i]==lab[i-1])
258 SG_ERROR(
"dependent variable type should be either 0(nominal) or 1(ordinal) or 2(continuous)\n");
270 for (int32_t v=1;v<num_vecs;v++)
272 for (int32_t f=0;f<num_feats;f++)
276 if (mat(f,v)!=mat(f,v-1))
292 if (m_max_tree_depth>0)
294 if (level==m_max_tree_depth)
299 if (m_min_node_size>1)
301 if (num_vecs<m_min_node_size)
309 for (int32_t i=0;i<num_feats;i++)
312 for (int32_t j=0;j<num_vecs;j++)
317 if (m_feature_types[i]==0)
318 cat=merge_categories_nominal(feats,labels_vec,weights,pv);
319 else if (m_feature_types[i]==1)
320 cat=merge_categories_ordinal(feats,labels_vec,weights,pv);
322 SG_ERROR(
"feature type supported are 0(nominal) and 1(ordinal). m_feature_types[%d] is set %d\n",i,m_feature_types[i])
332 if (min_pv>m_alpha_split)
337 for (int32_t i=0;i<num_vecs;i++)
338 ufeats_best[i]=mat(attr_min,i);
341 for (int32_t i=0;i<cat_min.
vlen;i++)
347 for (int32_t j=0;j<num_vecs;j++)
349 for (int32_t k=0;k<unum;k++)
351 if (mat(attr_min,j)==ufeats_best[k])
369 node_t* child=CHAIDtrain(data,subweights,labels,level+1);
375 for (int32_t j=0;j<feat_class.
vlen;j++)
377 if (feat_class[j]!=j)
387 for (int32_t k=j;k<feat_class.
vlen;k++)
389 if (feat_class[k]==j)
398 for (int32_t j=0;j<unum;j++)
424 if (ufeats[inum_cat-1]==
MISSING)
430 int32_t fnum_cat=inum_cat;
435 pv=adjusted_p_value(p_value(feats,labels,weights),2,2,1,
true);
445 int32_t cat_index_max=-1;
447 for (int32_t i=0;i<inum_cat-1;i++)
449 if (cat[i]==cat[i+1])
457 for (int32_t j=0;j<feats.
vlen;j++)
459 for (int32_t k=0;k<inum_cat;k++)
461 if (feats[j]==ufeats[k])
463 if (cat[k]==cat[cat_index])
468 else if (cat[k]==cat[cat_index+1])
487 float64_t pvalue=p_value(subfeats,sublabels,subweights);
488 if (pvalue>max_merge_pv)
491 cat_index_max=cat_index;
498 if (max_merge_pv>m_alpha_merge)
501 int32_t cat2=cat[cat_index_max+1];
502 for (int32_t i=cat_index_max+1;i<inum_cat;i++)
505 cat[i]=cat[cat_index_max];
519 for (int32_t i=0;i<feats.
vlen;i++)
527 for (int32_t j=0;j<inum_cat;j++)
529 if (feats[i]==ufeats[j])
536 bool merged=handle_missing_ordinal(cat,feats_cat,labels,weights);
540 pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat+1,fnum_cat,1,
true);
544 pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat,fnum_cat,1,
false);
555 int32_t fnum_cat=inum_cat;
574 for (int32_t i=0;i<cat.
vlen;i++)
590 for (int32_t k=0;k<feats.
vlen;k++)
592 for (int32_t l=0;l<inum_cat;l++)
594 if (feats[k]==ufeats[l])
620 float64_t pvalue=p_value(subfeats,sublabels,subweights);
621 if (pvalue>max_merge_pv)
635 if (max_merge_pv>m_alpha_merge)
638 for (int32_t i=0;i<cat.
vlen;i++)
640 if (cat2_max==cat[i])
653 for (int32_t i=0;i<feats.
vlen;i++)
655 for (int32_t j=0;j<inum_cat;j++)
657 if (feats[i]==ufeats[j])
662 pv=adjusted_p_value(p_value(feats_cat,labels,weights),inum_cat,fnum_cat,0,
false);
672 modify_data_matrix(feats);
675 return apply_from_current_node(fmat,
m_root);
683 for (int32_t i=0;i<num_vecs;i++)
709 node=
dynamic_cast<node_t*
>(el);
712 SG_ERROR(
"%d child is expected to be present. But it is NULL\n",index)
730 SG_ERROR(
"Undefined problem type\n")
741 REQUIRE(cat[cat.
vlen-1]==cat.
vlen-1,
"last category is expected to be stored for MISSING. Hence it is expected to be un-merged\n")
743 for (int32_t i=0;i<cat.
vlen-1;i++)
751 int32_t cindex_max=-1;
755 for (int32_t j=0;j<feats.
vlen;j++)
771 float64_t pvalue=p_value(subfeats,sublabels,subweights);
772 if (pvalue>max_pv_pair)
783 for (int32_t i=0;i<feats.
vlen;i++)
786 feats_copy[i]=cindex_max;
788 feats_copy[i]=feats[i];
791 float64_t pv_merged=p_value(feats_copy, labels, weights);
792 float64_t pv_unmerged=p_value(feats, labels, weights);
793 if (pv_merged>pv_unmerged)
795 cat[cat.
vlen-1]=cindex_max;
796 for (int32_t i=0;i<feats.
vlen;i++)
808 float64_t CCHAIDTree::adjusted_p_value(
float64_t up_value, int32_t inum_cat, int32_t fnum_cat, int32_t ft,
bool is_missing)
811 if (inum_cat==fnum_cat)
819 for (int32_t v=0;v<fnum_cat;v++)
822 for (int32_t j=1;j<=v;j++)
823 lterm-=CMath::log(j);
825 for (int32_t j=1;j<=fnum_cat-v;j++)
826 lterm-=CMath::log(j);
844 SG_ERROR(
"Feature type must be either 0 (nominal) or 1 (ordinal). It is currently set as %d\n",ft)
852 switch (m_dependent_vartype)
858 float64_t x2=pchi2_statistic(feat,labels,weights,r,c);
865 float64_t h2=likelihood_ratio_statistic(feat,labels,weights,r,c);
870 int32_t nf=feat.
vlen;
872 float64_t f=anova_f_statistic(feat,labels,weights,num_cat);
880 SG_ERROR(
"Dependent variable type must be either 0 or 1 or 2. It is currently set as %d\n",m_dependent_vartype)
890 for (int32_t i=0;i<labels.
vlen;i++)
891 y_bar+=labels[i]*weights[i];
893 y_bar/=weights.
sum(weights);
903 for (int32_t n=0;n<feat.
vlen;n++)
905 for (int32_t i=0;i<r;i++)
907 if (feat[n]==ufeat[i])
909 numer[i]+=weights[n]*labels[n];
910 denom[i]+=weights[n];
919 for (int32_t i=0;i<r;i++)
921 for (int32_t n=0;n<feat.
vlen;n++)
923 if (feat[n]==ufeat[i])
925 nu+=weights[n]*
CMath::pow(((numer[i]/denom[i])-y_bar),2);
926 de+=weights[n]*
CMath::pow((labels[n]-(numer[i]/denom[i])),2);
935 de/=(feat.
vlen-r-0.f);
952 for (int32_t i=0;i<feat.
vlen;i++)
956 for (int32_t j=0;j<r;j++)
958 if (feat[i]==ufeat[j])
967 for (int32_t j=0;j<c;j++)
969 if (labels[i]==ulabels[j])
977 wt(row,col)+=weights[i];
987 for (int32_t i=0;i<r;i++)
989 for (int32_t j=0;j<c;j++)
990 ret+=expmat_row_effects(i,j)*
CMath::log(expmat_row_effects(i,j)/expmat_indep(i,j));
997 int32_t &r, int32_t &c)
1009 for (int32_t i=0;i<feat.
vlen;i++)
1013 for (int32_t j=0;j<r;j++)
1015 if (feat[i]==ufeat[j])
1024 for (int32_t j=0;j<c;j++)
1026 if (labels[i]==ulabels[j])
1034 wt(row,col)+=weights[i];
1040 for (int32_t i=0;i<r;i++)
1042 for (int32_t j=0;j<c;j++)
1043 ret+=
CMath::pow((ct(i,j)-expected_cf(i,j)),2)/expected_cf(i,j);
1057 for (int32_t i=0;i<r;i++)
1060 for (int32_t j=0;j<c;j++)
1065 for (int32_t i=0;i<c;i++)
1068 for (int32_t j=0;j<r;j++)
1077 for (int32_t j=0;j<c;j++)
1080 for (int32_t i=0;i<r;i++)
1084 numer+=w_j*score[j];
1090 for (int32_t i=0;i<r;i++)
1092 for (int32_t j=0;j<c;j++)
1093 wt(i,j)=(ct(i,j)-0.f)/wt(i,j);
1107 for (int32_t i=0;i<r;i++)
1110 for (int32_t j=0;j<c;j++)
1113 alpha[i]*=(row_sum[i]-0.f)/sum;
1117 for (int32_t j=0;j<c;j++)
1120 for (int32_t i=0;i<r;i++)
1121 sum+=wt(i,j)*alpha[i]*
CMath::pow(gamma[i],(score[j]-s_bar));
1123 beta[j]=(col_sum[j]-0.f)/sum;
1129 for (int32_t i=0;i<r;i++)
1131 for (int32_t j=0;j<c;j++)
1132 m_star(i,j)=wt(i,j)*alpha[i]*beta[j]*
CMath::pow(gamma[i],score[j]-s_bar);
1135 for (int32_t i=0;i<r;i++)
1139 for (int32_t j=0;j<c;j++)
1141 numer+=(score[j]-s_bar)*(ct(i,j)-m_star(i,j));
1142 denom+=
CMath::pow((score[j]-s_bar),2)*m_star(i,j);
1149 for (int32_t i=0;i<r;i++)
1150 gamma[i]=(g[i]>0)?gamma[i]*g[i]:gamma[i];
1155 for (int32_t i=0;i<r;i++)
1157 for (int32_t j=0;j<c;j++)
1159 m_kplus(i,j)=wt(i,j)*alpha[i]*beta[j]*
CMath::pow(gamma[i],(score[j]-s_bar));
1161 if (abs_diff>max_diff)
1167 if (max_diff<epsilon)
1182 for (int32_t i=0;i<r;i++)
1185 for (int32_t j=0;j<c;j++)
1190 for (int32_t i=0;i<c;i++)
1193 for (int32_t j=0;j<r;j++)
1204 int32_t total_sum=(r<=c)?row_sum.
sum(row_sum):col_sum.
sum(col_sum);
1206 for (int32_t i=0;i<r;i++)
1208 for (int32_t j=0;j<c;j++)
1209 ret(i,j)=(row_sum[i]*col_sum[j]-0.f)/(total_sum-0.f);
1215 for (int32_t i=0;i<r;i++)
1217 for (int32_t j=0;j<c;j++)
1218 wt(i,j)=(ct(i,j)-0.f)/wt(i,j);
1231 for (int32_t i=0;i<r;i++)
1234 for (int32_t j=0;j<c;j++)
1237 alpha[i]*=(row_sum[i]-0.f)/sum;
1241 for (int32_t j=0;j<c;j++)
1244 for (int32_t i=0;i<r;i++)
1245 sum+=wt(i,j)*alpha[i];
1247 beta[j]=(col_sum[j]-0.f)/sum;
1253 for (int32_t i=0;i<r;i++)
1255 for (int32_t j=0;j<c;j++)
1257 m_kplus(i,j)=wt(i,j)*alpha[i]*beta[j];
1259 if (abs_diff>max_diff)
1266 if (max_diff<epsilon)
1280 for (int32_t i=0;i<lab.
vlen;i++)
1282 mean+=lab[i]*weights[i];
1283 total_weight+=weights[i];
1288 for (int32_t i=0;i<lab.
vlen;i++)
1289 dev+=weights[i]*(lab[i]-mean)*(lab[i]-mean);
1297 int32_t count_cont=0;
1300 if (m_feature_types[i]==2)
1307 REQUIRE(m_num_breakpoints>0,
"Number of breakpoints for continuous to ordinal conversion not set.\n")
1313 if (m_feature_types[i]==2)
1320 for (int32_t i=0;i<count_cont;i++)
1326 for (int32_t j=0;j<values.vlen;j++)
1331 for (int32_t j=0;j<m_num_breakpoints;j++)
1337 m_cont_breakpoints(j,i)=values[end_pt];
1342 m_cont_breakpoints(j,i)=values[end_pt];
1348 modify_data_matrix(feats);
1358 if (m_feature_types[i]!=2)
1364 for (int32_t k=0;k<m_num_breakpoints;k++)
1378 void CCHAIDTree::init()
1382 m_dependent_vartype=0;
1383 m_weights_set=
false;
1389 m_num_breakpoints=0;
1399 SG_ADD(&m_cont_breakpoints,
"m_cont_breakpoints",
"breakpoints in continuous attributes",
MS_NOT_AVAILABLE);
CTreeMachineNode< CHAIDTreeNodeData > node_t
void range_fill(T start=0)
static CRegressionLabels * to_regression(CLabels *base_labels)
float64_t weight_minus_node
static void fill_vector(T *vec, int32_t len, T value)
structure to store data of a node of CHAID. This can be used as a template type in TreeMachineNode cl...
virtual ELabelType get_label_type() const =0
SGVector< int32_t > feature_class
Real Labels are real-valued labels.
ST * get_feature_vector(int32_t num, int32_t &len, bool &dofree)
int32_t get_num_features() const
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
SGVector< float64_t > distinct_features
The class Labels models labels, i.e. class assignments of objects.
static float64_t fdistribution_cdf(float64_t x, float64_t d1, float64_t d2)
real valued labels (e.g. for regression, classifier outputs)
multi-class labels 0,1,...
virtual bool train_machine(CFeatures *data=NULL)
SGMatrix< ST > get_feature_matrix()
static const float64_t MIN_REAL_NUMBER
virtual void add_child(CTreeMachineNode *child)
int32_t get_num_elements() const
CTreeMachineNode< CHAIDTreeNodeData > * m_root
void set_dependent_vartype(int32_t var)
void set_root(CTreeMachineNode< CHAIDTreeNodeData > *root)
static void qsort(T *output, int32_t size)
Multiclass Labels for multi-class classification.
static const float64_t epsilon
Class SGObject is the base class of all shogun objects.
virtual int32_t get_num_vectors() const
virtual EProblemType get_machine_problem_type() const
virtual void remove_subset()
virtual void add_subset(SGVector< index_t > subset)
static T sum(T *vec, int32_t len)
Return sum(vec)
virtual EFeatureClass get_feature_class() const =0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
virtual CDynamicObjectArray * get_children()
SGVector< int32_t > get_feature_types() const
virtual bool is_label_valid(CLabels *lab) const
void clear_feature_types()
static int64_t nchoosek(int32_t n, int32_t k)
all of classes and functions are contained in the shogun namespace
virtual void remove_subset()
The class Features is the base class of all feature objects.
static float64_t exp(float64_t x)
void set_feature_types(SGVector< int32_t > ft)
static float64_t log(float64_t v)
static CDenseFeatures * obtain_from_generic(CFeatures *const base_features)
SGVector< T > clone() const
int32_t get_num_elements() const
SGVector< float64_t > get_weights() const
CSGObject * get_element(int32_t index) const
static const float64_t MISSING
Matrix::Scalar max(Matrix m)
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
static float64_t chi2_cdf(float64_t x, float64_t k)
Dense integer or floating point labels.
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
static CMulticlassLabels * to_multiclass(CLabels *base_labels)
static int32_t unique(T *output, int32_t size)
static int32_t pow(bool x, int32_t n)
static const float64_t MAX_REAL_NUMBER
const T & get_element(int32_t idx1, int32_t idx2=0, int32_t idx3=0) const
virtual void add_subset(SGVector< index_t > subset)
void set_weights(SGVector< float64_t > w)