00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/StratifiedCrossValidationSplitting.h> 00012 #include <shogun/features/Labels.h> 00013 #include <shogun/lib/Set.h> 00014 00015 using namespace shogun; 00016 00017 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting() : 00018 CSplittingStrategy(0, 0) 00019 { 00020 } 00021 00022 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting( 00023 CLabels* labels, index_t num_subsets) : 00024 CSplittingStrategy(labels, num_subsets) 00025 { 00026 build_subsets(); 00027 } 00028 00029 void CStratifiedCrossValidationSplitting::build_subsets() 00030 { 00031 /* extract all labels */ 00032 CSet<float64_t> unique_labels; 00033 for (index_t i=0; i<m_labels->get_num_labels(); ++i) 00034 unique_labels.add(m_labels->get_label(i)); 00035 00036 /* for every label, build set for indices */ 00037 CDynamicObjectArray<CDynamicArray<index_t> > label_indices; 00038 for (index_t i=0; i<unique_labels.get_num_elements(); ++i) 00039 label_indices.append_element(new CDynamicArray<index_t> ()); 00040 00041 /* fill set with indices, for each label type ... */ 00042 for (index_t i=0; i<unique_labels.get_num_elements(); ++i) 00043 { 00044 /* ... iterate over all labels and add indices with same label to set */ 00045 for (index_t j=0; j<m_labels->get_num_labels(); ++j) 00046 { 00047 if (m_labels->get_label(j)==unique_labels[i]) 00048 { 00049 CDynamicArray<index_t>* current=label_indices.get_element(i); 00050 current->append_element(j); 00051 SG_UNREF(current); 00052 } 00053 } 00054 } 00055 00056 /* shuffle created label sets */ 00057 for (index_t i=0; i<label_indices.get_num_elements(); ++i) 00058 { 00059 CDynamicArray<index_t>* current=label_indices.get_element(i); 00060 current->shuffle(); 00061 SG_UNREF(current); 00062 } 00063 00064 /* distribute labels to subsets for all label types */ 00065 index_t target_set=0; 00066 for (index_t i=0; i<unique_labels.get_num_elements(); ++i) 00067 { 00068 /* current index set for current label */ 00069 CDynamicArray<index_t>* current=label_indices.get_element(i); 00070 00071 for (index_t j=0; j<current->get_num_elements(); ++j) 00072 { 00073 CDynamicArray<index_t>* next=m_subset_indices->get_element( 00074 target_set++); 00075 next->append_element(current->get_element(j)); 00076 target_set%=m_subset_indices->get_num_elements(); 00077 SG_UNREF(next); 00078 } 00079 00080 SG_UNREF(current); 00081 } 00082 00083 /* finally shuffle to avoid that subsets with low indices have more 00084 * elements, which happens if the number of class labels is not equal to 00085 * the number of subsets */ 00086 m_subset_indices->shuffle(); 00087 }