SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StratifiedCrossValidationSplitting.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2011-2012 Heiko Strathmann
8  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9  */
10 
12 #include <shogun/labels/Labels.h>
15 
16 using namespace shogun;
17 
// Presumably the default constructor: its signature (listing lines 18-19) is
// missing from this rendered page, so only the body is visible here.
// The visible body only selects the random number source: m_rng is set to
// sg_rand, and m_rng is later passed to shuffle() when the subsets are built.
20 {
21  m_rng = sg_rand; // random source used for all shuffling in this class
22 }
23 
// Constructor taking the labels to split and the number of subsets. The first
// line of its signature (listing line 24) is missing from this rendered page.
// Forwards labels and num_subsets to CSplittingStrategy, reports SG_ERROR unless
// the labels are multiclass or binary, emits SG_WARNING for every class that
// has fewer labels than num_subsets, and finally sets m_rng to sg_rand.
25  CLabels* labels, index_t num_subsets) :
26  CSplittingStrategy(labels, num_subsets)
27 {
28  /* check for "stupid" combinations of label numbers and num_subsets:
29  * if a class has fewer labels than num_subsets, that class cannot
30  * appear in every subset; in the extreme case of a two-class labeling
31  * this can produce subsets that contain only one class. */
32  SGVector<float64_t> classes;
33 
34  int32_t num_classes=2; // value for the binary case; overwritten for multiclass labels
35  if (labels->get_label_type() == LT_MULTICLASS)
36  {
// NOTE(review): num_classes and classes come from two separate calls, but the
// counting loop below reads classes.vector[i] for every i < num_classes. This
// assumes get_num_classes() equals get_unique_labels().vlen -- confirm; using
// classes.vlen would remove the assumption.
37  num_classes=((CMulticlassLabels*) labels)->get_num_classes();
38  classes=((CMulticlassLabels*) labels)->get_unique_labels();
39  }
40  else if (labels->get_label_type() == LT_BINARY)
41  {
// binary labels are assumed to be exactly -1 and +1; any other value would
// never match below and would not be counted
42  classes=SGVector<float64_t>(2);
43  classes[0]=-1;
44  classes[1]=+1;
45  }
46  else
47  {
48  SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
49  }
50 
51  SGVector<index_t> labels_per_class(num_classes);
52 
// count, for each class value, how many labels are exactly equal to it
// (one full pass over all labels per class)
53  for (index_t i=0; i<num_classes; ++i)
54  {
55  labels_per_class.vector[i]=0;
56  for (index_t j=0; j<labels->get_num_labels(); ++j)
57  {
58  if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j))
59  labels_per_class.vector[i]++;
60  }
61  }
62 
// emit SG_WARNING for every class with fewer labels than num_subsets
63  for (index_t i=0; i<num_classes; ++i)
64  {
65  if (labels_per_class.vector[i]<num_subsets)
66  {
67  SG_WARNING("There are only %d labels of class %.18g, but %d "
68  "subsets. Labels of that class will not appear in every "
69  "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
70  }
71  }
72 
73  m_rng = sg_rand; // m_rng is later passed to shuffle() when the subsets are built
74 }
75 
// Presumably the subset-building routine (e.g. build_subsets()); its signature,
// listing line 76, is missing from this rendered page, as are several interior
// lines noted below.
// Steps visible here:
// 1. Empty the subsets and mark them as filled.
// 2. Collect the label indices of every class value into its own array.
// 3. Shuffle each per-class array with m_rng.
// 4. Deal the indices round-robin into the arrays held in m_subset_indices,
//    so every class is spread over all subsets.
77 {
78  /* ensure that subsets are empty and set flag to filled */
79  reset_subsets();
80  m_is_filled=true;
81 
82  SGVector<float64_t> unique_labels;
83 
// NOTE(review): the condition for this first branch (listing line 84) is missing
// from this page; presumably it tests for LT_MULTICLASS, mirroring the
// constructor -- confirm against the real source.
85  {
86  unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels();
87  }
88  else if (m_labels->get_label_type() == LT_BINARY)
89  {
// binary labels are assumed to be exactly -1 and +1
90  unique_labels=SGVector<float64_t>(2);
91  unique_labels[0]=-1;
92  unique_labels[1]=+1;
93  }
94  else
95  {
96  SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
97  }
98 
99  /* for every label, build set for indices */
100  CDynamicObjectArray label_indices;
101  for (index_t i=0; i<unique_labels.vlen; ++i)
102  label_indices.append_element(new CDynamicArray<index_t> ());
103 
104  /* fill set with indices, for each label type ... */
105  for (index_t i=0; i<unique_labels.vlen; ++i)
106  {
107  /* ... iterate over all labels and add indices with same label to set */
108  for (index_t j=0; j<m_labels->get_num_labels(); ++j)
109  {
110  if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i])
111  {
// NOTE(review): listing line 112, which declares `current` from the element
// fetched below, is missing from this page.
113  label_indices.get_element(i);
114  current->append_element(j);
115  SG_UNREF(current);
116  }
117  }
118  }
119 
120  /* shuffle created label sets */
121  for (index_t i=0; i<label_indices.get_num_elements(); ++i)
122  {
// NOTE(review): listing line 123, which declares `current`, is missing from this page.
124  label_indices.get_element(i);
125 
126  // external random state important for threads
127  current->shuffle(m_rng);
128 
129  SG_UNREF(current);
130  }
131 
132  /* distribute labels to subsets for all label types */
// target_set is initialized once and never reset per class, so the round-robin
// rotation over the subsets continues from one class to the next
133  index_t target_set=0;
134  for (index_t i=0; i<unique_labels.vlen; ++i)
135  {
136  /* current index set for current label */
// NOTE(review): listing line 137, which declares `current`, is missing from this page.
138  label_indices.get_element(i);
139 
140  for (index_t j=0; j<current->get_num_elements(); ++j)
141  {
// NOTE(review): listing line 142, which declares `next` (the target subset's
// index array), is missing from this page.
143  m_subset_indices->get_element(target_set++);
144  next->append_element(current->get_element(j));
145  target_set%=m_subset_indices->get_num_elements(); // wrap around: deal round-robin
146  SG_UNREF(next);
147  }
148 
149  SG_UNREF(current);
150  }
151 
152  /* finally shuffle, so that subsets with low indices do not systematically
153  * have more elements; this happens when the total number of labels is not a
154  * multiple of the number of subsets (external random state important for threads) */
// NOTE(review): the shuffle statement itself (listing line 155) is missing from this page.
156 }

SHOGUN Machine Learning Toolbox - Documentation