SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StratifiedCrossValidationSplitting.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/labels/Labels.h>
15 
16 using namespace shogun;
17 
19  CSplittingStrategy(0, 0)
20 {
  /* Intentionally empty: the only work done is forwarding 0, 0 to the
   * CSplittingStrategy base initializer above.
   * NOTE(review): the declarator line that precedes the initializer is not
   * present in this listing; presumably this is the parameterless
   * constructor - confirm against the class header. */
21 }
22 
24  CLabels* labels, index_t num_subsets) :
25  CSplittingStrategy(labels, num_subsets)
26 {
  /* Forwards labels and num_subsets to CSplittingStrategy, then counts how
   * many labels every class has and emits one SG_WARNING per class that has
   * fewer labels than num_subsets. Only LT_MULTICLASS and LT_BINARY label
   * types are handled; any other type is reported through SG_ERROR.
   * NOTE(review): labels is dereferenced below without a null check -
   * confirm that callers never pass a null pointer to this constructor. */
27  /* Check for "stupid" combinations of label counts and num_subsets.
28  * If a class has fewer labels than num_subsets, that class cannot
29  * appear in every subset; in the extreme case of a two-class labeling
30  * this leads to subsets that contain only one class. */
31  SGVector<float64_t> classes; /* one entry per class: the label value of that class */
32 
33  int32_t num_classes=2; /* value used for binary labels; overwritten for multiclass */
34  if (labels->get_label_type() == LT_MULTICLASS)
35  {
36  num_classes=((CMulticlassLabels*) labels)->get_num_classes();
37  classes=((CMulticlassLabels*) labels)->get_unique_labels();
38  }
39  else if (labels->get_label_type() == LT_BINARY)
40  {
  /* the two binary classes are hard-coded as -1 and +1 */
41  classes=SGVector<float64_t>(2);
42  classes[0]=-1;
43  classes[1]=+1;
44  }
45  else
46  {
  /* NOTE(review): the code after this branch indexes classes up to
   * num_classes (2), which is only safe if SG_ERROR does not return -
   * confirm. */
47  SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n");
48  }
49 
50  SGVector<index_t> labels_per_class(num_classes);
51 
  /* for every class, zero its counter and then scan all labels once,
   * counting those whose value compares equal (exact ==) to the class value */
52  for (index_t i=0; i<num_classes; ++i)
53  {
54  labels_per_class.vector[i]=0;
55  for (index_t j=0; j<labels->get_num_labels(); ++j)
56  {
57  if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j))
58  labels_per_class.vector[i]++;
59  }
60  }
61 
  /* warn about each class that has fewer labels than there are subsets;
   * construction continues after the warning */
62  for (index_t i=0; i<num_classes; ++i)
63  {
64  if (labels_per_class.vector[i]<num_subsets)
65  {
66  SG_WARNING("There are only %d labels of class %.18g, but %d "
67  "subsets. Labels of that class will not appear in every "
68  "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
69  }
70  }
71 }
72 
74 {
  /* Distributes label indices over the subsets class by class: the indices
   * are grouped by label value, every group is shuffled, and the indices of
   * each group are then appended to the subsets one subset at a time in
   * cyclic order.
   * NOTE(review): the signature of this function and listing lines 81, 109,
   * 120, 131, 136 and 149 are not present in this listing; the comments
   * below describe only what is visible and mark the gaps. */
75  /* reset the subsets (expected to leave them empty) and set the filled flag */
76  reset_subsets();
77  m_is_filled=true;
78 
79  SGVector<float64_t> unique_labels; /* the distinct label values, one per class */
80 
  /* NOTE(review): the condition that opens this branch (listing line 81)
   * is missing here; presumably it tests for LT_MULTICLASS, since the
   * branch casts m_labels to CMulticlassLabels - confirm. */
82  {
83  unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels();
84  }
85  else if (m_labels->get_label_type() == LT_BINARY)
86  {
  /* the two binary classes are hard-coded as -1 and +1 */
87  unique_labels=SGVector<float64_t>(2);
88  unique_labels[0]=-1;
89  unique_labels[1]=+1;
90  }
91  else
92  {
93  SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n");
94  }
95 
96  /* create one (initially empty) index array per unique label value */
97  CDynamicObjectArray label_indices;
98  for (index_t i=0; i<unique_labels.vlen; ++i)
99  label_indices.append_element(new CDynamicArray<index_t> ());
100 
101  /* fill the index arrays: for each unique label value ... */
102  for (index_t i=0; i<unique_labels.vlen; ++i)
103  {
104  /* ... scan all labels and append the index of every label that compares
  * equal (exact ==) to that value to the array of that value */
105  for (index_t j=0; j<m_labels->get_num_labels(); ++j)
106  {
107  if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i])
108  {
  /* NOTE(review): the declaration of current (listing line 109) is
   * missing; the next line is the tail of its initializer. The
   * SG_UNREF below presumably releases the reference obtained from
   * get_element - confirm. */
110  label_indices.get_element(i);
111  current->append_element(j);
112  SG_UNREF(current);
113  }
114  }
115  }
116 
117  /* shuffle every per-label index array */
118  for (index_t i=0; i<label_indices.get_num_elements(); ++i)
119  {
  /* NOTE(review): the declaration of current (listing line 120) is missing */
121  label_indices.get_element(i);
122  current->shuffle();
123  SG_UNREF(current);
124  }
125 
126  /* hand the indices of every label value out to the subsets; target_set is
  * declared outside the loop, so it is not reset between label values and
  * each value continues with the subset after the last one used */
127  index_t target_set=0;
128  for (index_t i=0; i<unique_labels.vlen; ++i)
129  {
130  /* index array of the current label value
  * NOTE(review): the declaration of current (listing line 131) is missing */
132  label_indices.get_element(i);
133 
134  for (index_t j=0; j<current->get_num_elements(); ++j)
135  {
  /* append the j-th index to the subset at target_set, advance
   * target_set and wrap it around at the number of subsets
   * NOTE(review): the declaration of next (listing line 136) is missing */
137  m_subset_indices->get_element(target_set++);
138  next->append_element(current->get_element(j));
139  target_set%=m_subset_indices->get_num_elements();
140  SG_UNREF(next);
141  }
142 
143  SG_UNREF(current);
144  }
145 
146  /* Finally, shuffle so that the subsets with low indices are not always
147  * the ones holding more elements, which happens if the number of labels of
148  * a class is not equal to the number of subsets.
  * NOTE(review): the statement this comment announces (listing line 149)
  * is not present in this listing. */
150 }

SHOGUN Machine Learning Toolbox - Documentation