SplittingStrategy.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/SplittingStrategy.h>
00012 #include <shogun/features/Labels.h>
00013 
00014 using namespace shogun;
00015 
00016 CSplittingStrategy::CSplittingStrategy()
00017 {
00018     init();
00019 }
00020 
00021 CSplittingStrategy::CSplittingStrategy(CLabels* labels, int32_t num_subsets)
00022 {
00023     init();
00024 
00025     /* "assert" that num_subsets is smaller than num labels */
00026     if (labels->get_num_labels()<num_subsets)
00027     {
00028         SG_ERROR("Only %d labels for %d subsets!\n", labels->get_num_labels(),
00029                 num_subsets);
00030     }
00031 
00032     /* check for "stupid" combinations of label numbers and num_subsets.
00033      * if there are of a class less labels than num_subsets, the class will not
00034      * appear in every subset, leading to subsets of only one class in the
00035      * extreme case of a two class labeling. */
00036     SGVector<index_t> labels_per_class(labels->get_num_classes());
00037     SGVector<float64_t> classes=labels->get_classes();
00038 
00039     for (index_t i=0; i<labels->get_num_classes(); ++i)
00040     {
00041         labels_per_class.vector[i]=0;
00042         for (index_t j=0; j<labels->get_num_labels(); ++j)
00043         {
00044             if (classes.vector[i]==labels->get_label(j))
00045                 labels_per_class.vector[i]++;
00046         }
00047     }
00048 
00049     for (index_t i=0; i<labels->get_num_classes(); ++i)
00050     {
00051         if (labels_per_class.vector[i]<num_subsets)
00052         {
00053             SG_WARNING("There are only %d labels of class %.18g, but %d "
00054                     "subsets. Labels of that class will not appear in every "
00055                     "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
00056         }
00057     }
00058 
00059     labels_per_class.destroy_vector();
00060     classes.destroy_vector();
00061 
00062     m_labels=labels;
00063     SG_REF(m_labels);
00064 
00065     /* construct all arrays */
00066     for (index_t i=0; i<num_subsets; ++i)
00067         m_subset_indices->append_element(new CDynamicArray<index_t> ());
00068 }
00069 
00070 void CSplittingStrategy::init()
00071 {
00072     m_labels=NULL;
00073     m_subset_indices=new CDynamicObjectArray<CDynamicArray<index_t> >();
00074     SG_REF(m_subset_indices);
00075 
00076     m_parameters->add((CSGObject**)m_labels, "labels", "Labels for subsets");
00077     m_parameters->add((CSGObject**)m_subset_indices, "subset_indices",
00078             "Set of sets of subset indices");
00079 }
00080 
00081 CSplittingStrategy::~CSplittingStrategy()
00082 {
00083     SG_UNREF(m_labels);
00084     SG_UNREF(m_subset_indices);
00085 }
00086 
00087 SGVector<index_t> CSplittingStrategy::generate_subset_indices(index_t subset_idx)
00088 {
00089     /* construct SGVector copy from index vector */
00090     CDynamicArray<index_t>* to_copy=m_subset_indices->get_element_safe(
00091             subset_idx);
00092 
00093     index_t num_elements=to_copy->get_num_elements();
00094     SGVector<index_t> result(num_elements, true);
00095 
00096     /* copy data */
00097     memcpy(result.vector, to_copy->get_array(), sizeof(index_t)*num_elements);
00098 
00099     SG_UNREF(to_copy);
00100 
00101     return result;
00102 }
00103 
00104 SGVector<index_t> CSplittingStrategy::generate_subset_inverse(index_t subset_idx)
00105 {
00106     CDynamicArray<index_t>* to_invert=m_subset_indices->get_element_safe(
00107             subset_idx);
00108 
00109     SGVector<index_t> result(
00110             m_labels->get_num_labels()-to_invert->get_num_elements(), true);
00111 
00112     index_t index=0;
00113     for (index_t i=0; i<m_labels->get_num_labels(); ++i)
00114     {
00115         /* add i to inverse indices if it is not in the to be inverted set */
00116         if (to_invert->find_element(i)==-1)
00117             result.vector[index++]=i;
00118     }
00119 
00120     SG_UNREF(to_invert);
00121 
00122     return result;
00123 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation