00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/StratifiedCrossValidationSplitting.h> 00012 #include <shogun/labels/Labels.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 #include <shogun/labels/MulticlassLabels.h> 00015 00016 using namespace shogun; 00017 00018 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting() : 00019 CSplittingStrategy(0, 0) 00020 { 00021 } 00022 00023 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting( 00024 CLabels* labels, index_t num_subsets) : 00025 CSplittingStrategy(labels, num_subsets) 00026 { 00027 /* check for "stupid" combinations of label numbers and num_subsets. 00028 * if there are of a class less labels than num_subsets, the class will not 00029 * appear in every subset, leading to subsets of only one class in the 00030 * extreme case of a two class labeling. */ 00031 SGVector<float64_t> classes; 00032 00033 int32_t num_classes=2; 00034 if (labels->get_label_type() == LT_MULTICLASS) 00035 { 00036 num_classes=((CMulticlassLabels*) labels)->get_num_classes(); 00037 classes=((CMulticlassLabels*) labels)->get_unique_labels(); 00038 } 00039 else if (labels->get_label_type() == LT_BINARY) 00040 { 00041 classes=SGVector<float64_t>(2); 00042 classes[0]=-1; 00043 classes[1]=+1; 00044 } 00045 else 00046 { 00047 SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n"); 00048 } 00049 00050 SGVector<index_t> labels_per_class(num_classes); 00051 00052 for (index_t i=0; i<num_classes; ++i) 00053 { 00054 labels_per_class.vector[i]=0; 00055 for (index_t j=0; j<labels->get_num_labels(); ++j) 00056 { 00057 if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j)) 00058 labels_per_class.vector[i]++; 00059 } 00060 } 00061 00062 for (index_t i=0; i<num_classes; ++i) 00063 { 00064 if (labels_per_class.vector[i]<num_subsets) 00065 { 00066 SG_WARNING("There are only %d labels of class %.18g, but %d " 00067 "subsets. Labels of that class will not appear in every " 00068 "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets); 00069 } 00070 } 00071 } 00072 00073 void CStratifiedCrossValidationSplitting::build_subsets() 00074 { 00075 /* ensure that subsets are empty and set flag to filled */ 00076 reset_subsets(); 00077 m_is_filled=true; 00078 00079 SGVector<float64_t> unique_labels; 00080 00081 if (m_labels->get_label_type() == LT_MULTICLASS) 00082 { 00083 unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels(); 00084 } 00085 else if (m_labels->get_label_type() == LT_BINARY) 00086 { 00087 unique_labels=SGVector<float64_t>(2); 00088 unique_labels[0]=-1; 00089 unique_labels[1]=+1; 00090 } 00091 else 00092 { 00093 SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n"); 00094 } 00095 00096 /* for every label, build set for indices */ 00097 CDynamicObjectArray label_indices; 00098 for (index_t i=0; i<unique_labels.vlen; ++i) 00099 label_indices.append_element(new CDynamicArray<index_t> ()); 00100 00101 /* fill set with indices, for each label type ... */ 00102 for (index_t i=0; i<unique_labels.vlen; ++i) 00103 { 00104 /* ... iterate over all labels and add indices with same label to set */ 00105 for (index_t j=0; j<m_labels->get_num_labels(); ++j) 00106 { 00107 if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i]) 00108 { 00109 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00110 label_indices.get_element(i); 00111 current->append_element(j); 00112 SG_UNREF(current); 00113 } 00114 } 00115 } 00116 00117 /* shuffle created label sets */ 00118 for (index_t i=0; i<label_indices.get_num_elements(); ++i) 00119 { 00120 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00121 label_indices.get_element(i); 00122 current->shuffle(); 00123 SG_UNREF(current); 00124 } 00125 00126 /* distribute labels to subsets for all label types */ 00127 index_t target_set=0; 00128 for (index_t i=0; i<unique_labels.vlen; ++i) 00129 { 00130 /* current index set for current label */ 00131 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00132 label_indices.get_element(i); 00133 00134 for (index_t j=0; j<current->get_num_elements(); ++j) 00135 { 00136 CDynamicArray<index_t>* next=(CDynamicArray<index_t>*) 00137 m_subset_indices->get_element(target_set++); 00138 next->append_element(current->get_element(j)); 00139 target_set%=m_subset_indices->get_num_elements(); 00140 SG_UNREF(next); 00141 } 00142 00143 SG_UNREF(current); 00144 } 00145 00146 /* finally shuffle to avoid that subsets with low indices have more 00147 * elements, which happens if the number of class labels is not equal to 00148 * the number of subsets */ 00149 m_subset_indices->shuffle(); 00150 }