MulticlassLabels.cpp

Go to the documentation of this file.
#include <shogun/labels/DenseLabels.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/MulticlassLabels.h>

#include <limits>
00004 
00005 using namespace shogun;
00006 
00007 CMulticlassLabels::CMulticlassLabels() : CDenseLabels()
00008 {
00009     m_multiclass_confidences = NULL;
00010     m_num_multiclass_confidences = 0;
00011 }
00012 
00013 CMulticlassLabels::CMulticlassLabels(int32_t num_labels) : CDenseLabels(num_labels)
00014 {
00015     m_multiclass_confidences = SG_MALLOC(SGVector<float64_t>, num_labels);
00016     m_num_multiclass_confidences = num_labels;
00017     for (int32_t i=0; i<num_labels; i++)
00018         new (&m_multiclass_confidences[i]) SGVector<float64_t>();
00019 }
00020 
00021 CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src) : CDenseLabels()
00022 {
00023     set_labels(src);
00024     m_multiclass_confidences = NULL;
00025     m_num_multiclass_confidences = 0;
00026 }
00027 
00028 CMulticlassLabels::CMulticlassLabels(CFile* loader) : CDenseLabels(loader)
00029 {
00030     m_multiclass_confidences = NULL;
00031     m_num_multiclass_confidences = 0;
00032 }
00033 
00034 CMulticlassLabels::~CMulticlassLabels()
00035 {
00036     for (int32_t i=0; i<m_num_multiclass_confidences; i++)
00037         m_multiclass_confidences[i].~SGVector<float64_t>();
00038     SG_FREE(m_multiclass_confidences);
00039 }
00040 
00041 void CMulticlassLabels::set_multiclass_confidences(int32_t i, SGVector<float64_t> confidences)
00042 {
00043     m_multiclass_confidences[i] = confidences;
00044 }
00045 
00046 SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i)
00047 {
00048     return m_multiclass_confidences[i];
00049 }
00050 
00051 CMulticlassLabels* CMulticlassLabels::obtain_from_generic(CLabels* base_labels)
00052 {
00053     if ( base_labels->get_label_type() == LT_MULTICLASS )
00054         return (CMulticlassLabels*) base_labels;
00055     else
00056         SG_SERROR("base_labels must be of dynamic type CMulticlassLabels");
00057 
00058     return NULL;
00059 }
00060 
00061 void CMulticlassLabels::ensure_valid(const char* context)
00062 {       
00063     CDenseLabels::ensure_valid(context);
00064 
00065     int32_t subset_size=get_num_labels();
00066     for (int32_t i=0; i<subset_size; i++)
00067     {
00068         int32_t real_i = m_subset_stack->subset_idx_conversion(i);
00069         int32_t label = int32_t(m_labels[real_i]);
00070 
00071         if (label<0 || float64_t(label)!=m_labels[real_i])
00072         {
00073             SG_ERROR("%s%sMulticlass Labels must be in range 0...<nr_classes-1> and integers!\n",
00074                     context?context:"", context?": ":"");
00075         }
00076     }
00077 }
00078 
00079 ELabelType CMulticlassLabels::get_label_type()
00080 {
00081     return LT_MULTICLASS;
00082 }
00083 
00084 CBinaryLabels* CMulticlassLabels::get_binary_for_class(int32_t i)
00085 {
00086     SGVector<float64_t> binary_labels(get_num_labels());
00087 
00088     bool use_confidences = false;
00089     if (m_num_multiclass_confidences != 0)
00090     {
00091         if (m_multiclass_confidences[i].size())
00092             use_confidences = true;
00093     }
00094     if (use_confidences)
00095     {
00096         for (int32_t k=0; k<binary_labels.vlen; k++)
00097         {
00098             SGVector<float64_t> confs = m_multiclass_confidences[k];
00099             int32_t label = get_int_label(k);
00100             binary_labels[k] = label == i ? confs[label] : -confs[label];
00101         }
00102     }
00103     else
00104     {
00105         for (int32_t k=0; k<binary_labels.vlen; k++)
00106         {
00107             int32_t label = get_int_label(k);
00108             binary_labels[k] = label == i ? +1.0 : -1.0;
00109         }
00110     }
00111     return new CBinaryLabels(binary_labels);
00112 }
00113 
00114 SGVector<float64_t> CMulticlassLabels::get_unique_labels()
00115 {
00116     /* extract all labels (copy because of possible subset) */
00117     SGVector<float64_t> unique_labels=get_labels_copy();
00118     unique_labels.vlen=SGVector<float64_t>::unique(unique_labels.vector, unique_labels.vlen);
00119 
00120     SGVector<float64_t> result(unique_labels.vlen);
00121     memcpy(result.vector, unique_labels.vector,
00122             sizeof(float64_t)*unique_labels.vlen);
00123 
00124     return result;
00125 }
00126 
00127 
00128 int32_t CMulticlassLabels::get_num_classes()
00129 {
00130     SGVector<float64_t> unique=get_unique_labels();
00131     return unique.vlen;
00132 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation