StratifiedCrossValidationSplitting.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
00012 #include <shogun/labels/Labels.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 #include <shogun/labels/MulticlassLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting() :
00019     CSplittingStrategy(0, 0)
00020 {
00021 }
00022 
00023 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting(
00024         CLabels* labels, index_t num_subsets) :
00025     CSplittingStrategy(labels, num_subsets)
00026 {
00027     /* check for "stupid" combinations of label numbers and num_subsets.
00028      * if there are of a class less labels than num_subsets, the class will not
00029      * appear in every subset, leading to subsets of only one class in the
00030      * extreme case of a two class labeling. */
00031     SGVector<float64_t> classes;
00032 
00033     int32_t num_classes=2;
00034     if (labels->get_label_type() == LT_MULTICLASS)
00035     {
00036         num_classes=((CMulticlassLabels*) labels)->get_num_classes();
00037         classes=((CMulticlassLabels*) labels)->get_unique_labels();
00038     }
00039     else if (labels->get_label_type() == LT_BINARY)
00040     {
00041         classes=SGVector<float64_t>(2);
00042         classes[0]=-1;
00043         classes[1]=+1;
00044     }
00045     else
00046     {
00047         SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n");
00048     }
00049 
00050     SGVector<index_t> labels_per_class(num_classes);
00051 
00052     for (index_t i=0; i<num_classes; ++i)
00053     {
00054         labels_per_class.vector[i]=0;
00055         for (index_t j=0; j<labels->get_num_labels(); ++j)
00056         {
00057             if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j))
00058                 labels_per_class.vector[i]++;
00059         }
00060     }
00061 
00062     for (index_t i=0; i<num_classes; ++i)
00063     {
00064         if (labels_per_class.vector[i]<num_subsets)
00065         {
00066             SG_WARNING("There are only %d labels of class %.18g, but %d "
00067                     "subsets. Labels of that class will not appear in every "
00068                     "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
00069         }
00070     }
00071 }
00072 
00073 void CStratifiedCrossValidationSplitting::build_subsets()
00074 {
00075     /* ensure that subsets are empty and set flag to filled */
00076     reset_subsets();
00077     m_is_filled=true;
00078 
00079     SGVector<float64_t> unique_labels;
00080 
00081     if (m_labels->get_label_type() == LT_MULTICLASS)
00082     {
00083         unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels();
00084     }
00085     else if (m_labels->get_label_type() == LT_BINARY)
00086     {
00087         unique_labels=SGVector<float64_t>(2);
00088         unique_labels[0]=-1;
00089         unique_labels[1]=+1;
00090     }
00091     else
00092     {
00093         SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n");
00094     }
00095 
00096     /* for every label, build set for indices */
00097     CDynamicObjectArray label_indices;
00098     for (index_t i=0; i<unique_labels.vlen; ++i)
00099         label_indices.append_element(new CDynamicArray<index_t> ());
00100 
00101     /* fill set with indices, for each label type ... */
00102     for (index_t i=0; i<unique_labels.vlen; ++i)
00103     {
00104         /* ... iterate over all labels and add indices with same label to set */
00105         for (index_t j=0; j<m_labels->get_num_labels(); ++j)
00106         {
00107             if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i])
00108             {
00109                 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00110                         label_indices.get_element(i);
00111                 current->append_element(j);
00112                 SG_UNREF(current);
00113             }
00114         }
00115     }
00116 
00117     /* shuffle created label sets */
00118     for (index_t i=0; i<label_indices.get_num_elements(); ++i)
00119     {
00120         CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00121                 label_indices.get_element(i);
00122         current->shuffle();
00123         SG_UNREF(current);
00124     }
00125 
00126     /* distribute labels to subsets for all label types */
00127     index_t target_set=0;
00128     for (index_t i=0; i<unique_labels.vlen; ++i)
00129     {
00130         /* current index set for current label */
00131         CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00132                 label_indices.get_element(i);
00133 
00134         for (index_t j=0; j<current->get_num_elements(); ++j)
00135         {
00136             CDynamicArray<index_t>* next=(CDynamicArray<index_t>*)
00137                     m_subset_indices->get_element(target_set++);
00138             next->append_element(current->get_element(j));
00139             target_set%=m_subset_indices->get_num_elements();
00140             SG_UNREF(next);
00141         }
00142 
00143         SG_UNREF(current);
00144     }
00145 
00146     /* finally shuffle to avoid that subsets with low indices have more
00147      * elements, which happens if the number of class labels is not equal to
00148      * the number of subsets */
00149     m_subset_indices->shuffle();
00150 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation