00001 #include <shogun/labels/DenseLabels.h> 00002 #include <shogun/labels/BinaryLabels.h> 00003 #include <shogun/labels/MulticlassLabels.h> 00004 00005 using namespace shogun; 00006 00007 CMulticlassLabels::CMulticlassLabels() : CDenseLabels() 00008 { 00009 m_multiclass_confidences = NULL; 00010 m_num_multiclass_confidences = 0; 00011 } 00012 00013 CMulticlassLabels::CMulticlassLabels(int32_t num_labels) : CDenseLabels(num_labels) 00014 { 00015 m_multiclass_confidences = SG_MALLOC(SGVector<float64_t>, num_labels); 00016 m_num_multiclass_confidences = num_labels; 00017 for (int32_t i=0; i<num_labels; i++) 00018 new (&m_multiclass_confidences[i]) SGVector<float64_t>(); 00019 } 00020 00021 CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src) : CDenseLabels() 00022 { 00023 set_labels(src); 00024 m_multiclass_confidences = NULL; 00025 m_num_multiclass_confidences = 0; 00026 } 00027 00028 CMulticlassLabels::CMulticlassLabels(CFile* loader) : CDenseLabels(loader) 00029 { 00030 m_multiclass_confidences = NULL; 00031 m_num_multiclass_confidences = 0; 00032 } 00033 00034 CMulticlassLabels::~CMulticlassLabels() 00035 { 00036 for (int32_t i=0; i<m_num_multiclass_confidences; i++) 00037 m_multiclass_confidences[i].~SGVector<float64_t>(); 00038 SG_FREE(m_multiclass_confidences); 00039 } 00040 00041 void CMulticlassLabels::set_multiclass_confidences(int32_t i, SGVector<float64_t> confidences) 00042 { 00043 m_multiclass_confidences[i] = confidences; 00044 } 00045 00046 SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i) 00047 { 00048 return m_multiclass_confidences[i]; 00049 } 00050 00051 CMulticlassLabels* CMulticlassLabels::obtain_from_generic(CLabels* base_labels) 00052 { 00053 if ( base_labels->get_label_type() == LT_MULTICLASS ) 00054 return (CMulticlassLabels*) base_labels; 00055 else 00056 SG_SERROR("base_labels must be of dynamic type CMulticlassLabels"); 00057 00058 return NULL; 00059 } 00060 00061 void CMulticlassLabels::ensure_valid(const char* context) 00062 { 00063 CDenseLabels::ensure_valid(context); 00064 00065 int32_t subset_size=get_num_labels(); 00066 for (int32_t i=0; i<subset_size; i++) 00067 { 00068 int32_t real_i = m_subset_stack->subset_idx_conversion(i); 00069 int32_t label = int32_t(m_labels[real_i]); 00070 00071 if (label<0 || float64_t(label)!=m_labels[real_i]) 00072 { 00073 SG_ERROR("%s%sMulticlass Labels must be in range 0...<nr_classes-1> and integers!\n", 00074 context?context:"", context?": ":""); 00075 } 00076 } 00077 } 00078 00079 ELabelType CMulticlassLabels::get_label_type() 00080 { 00081 return LT_MULTICLASS; 00082 } 00083 00084 CBinaryLabels* CMulticlassLabels::get_binary_for_class(int32_t i) 00085 { 00086 SGVector<float64_t> binary_labels(get_num_labels()); 00087 00088 bool use_confidences = false; 00089 if (m_num_multiclass_confidences != 0) 00090 { 00091 if (m_multiclass_confidences[i].size()) 00092 use_confidences = true; 00093 } 00094 if (use_confidences) 00095 { 00096 for (int32_t k=0; k<binary_labels.vlen; k++) 00097 { 00098 SGVector<float64_t> confs = m_multiclass_confidences[k]; 00099 int32_t label = get_int_label(k); 00100 binary_labels[k] = label == i ? confs[label] : -confs[label]; 00101 } 00102 } 00103 else 00104 { 00105 for (int32_t k=0; k<binary_labels.vlen; k++) 00106 { 00107 int32_t label = get_int_label(k); 00108 binary_labels[k] = label == i ? +1.0 : -1.0; 00109 } 00110 } 00111 return new CBinaryLabels(binary_labels); 00112 } 00113 00114 SGVector<float64_t> CMulticlassLabels::get_unique_labels() 00115 { 00116 /* extract all labels (copy because of possible subset) */ 00117 SGVector<float64_t> unique_labels=get_labels_copy(); 00118 unique_labels.vlen=SGVector<float64_t>::unique(unique_labels.vector, unique_labels.vlen); 00119 00120 SGVector<float64_t> result(unique_labels.vlen); 00121 memcpy(result.vector, unique_labels.vector, 00122 sizeof(float64_t)*unique_labels.vlen); 00123 00124 return result; 00125 } 00126 00127 00128 int32_t CMulticlassLabels::get_num_classes() 00129 { 00130 SGVector<float64_t> unique=get_unique_labels(); 00131 return unique.vlen; 00132 }