00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/labels/MulticlassMultipleOutputLabels.h> 00011 00012 using namespace shogun; 00013 00014 CMulticlassMultipleOutputLabels::CMulticlassMultipleOutputLabels() 00015 : CLabels() 00016 { 00017 init(); 00018 } 00019 00020 CMulticlassMultipleOutputLabels::CMulticlassMultipleOutputLabels(int32_t num_labels) 00021 : CLabels() 00022 { 00023 init(); 00024 m_labels = SG_MALLOC(SGVector<index_t>, num_labels); 00025 m_n_labels = num_labels; 00026 for (int32_t i=0; i<m_n_labels; i++) 00027 new (&m_labels[i]) SGVector<index_t>(); 00028 } 00029 00030 CMulticlassMultipleOutputLabels::~CMulticlassMultipleOutputLabels() 00031 { 00032 for (int32_t i=0; i<m_n_labels; i++) 00033 m_labels[i].~SGVector<index_t>(); 00034 SG_FREE(m_labels); 00035 } 00036 00037 CMulticlassMultipleOutputLabels* CMulticlassMultipleOutputLabels::obtain_from_generic(CLabels* base_labels) 00038 { 00039 if (base_labels->get_label_type() == LT_MULTICLASS_MULTIPLE_OUTPUT) 00040 return (CMulticlassMultipleOutputLabels*) base_labels; 00041 else 00042 SG_SERROR("base_labels must be of dynamic type CMulticlassMultipleOutputLabels\n"); 00043 00044 return NULL; 00045 } 00046 00047 void CMulticlassMultipleOutputLabels::ensure_valid(const char* context) 00048 { 00049 if ( m_labels == NULL ) 00050 SG_ERROR("Non-valid MulticlassMultipleOutputLabels in %s", context); 00051 } 00052 00053 SGMatrix<index_t> CMulticlassMultipleOutputLabels::get_labels() const 00054 { 00055 if (m_n_labels==0) 00056 return SGMatrix<index_t>(); 00057 int n_outputs = m_labels[0].vlen; 00058 SGMatrix<index_t> labels(m_n_labels,n_outputs); 00059 for (int32_t i=0; i<m_n_labels; i++) 00060 { 00061 for (int32_t j=0; j<n_outputs; j++) 00062 labels(i,j) = m_labels[i][j]; 00063 } 00064 return labels; 00065 } 00066 00067 SGVector<index_t> CMulticlassMultipleOutputLabels::get_label(int32_t idx) 00068 { 00069 ensure_valid("CMulticlassMultipleOutputLabels::get_label(int32_t)"); 00070 if ( idx < 0 || idx >= get_num_labels() ) 00071 SG_ERROR("Index must be inside [0, num_labels-1]\n"); 00072 00073 return m_labels[m_subset_stack->subset_idx_conversion(idx)]; 00074 } 00075 00076 bool CMulticlassMultipleOutputLabels::set_label(int32_t idx, SGVector<index_t> label) 00077 { 00078 int32_t real_idx = m_subset_stack->subset_idx_conversion(idx); 00079 00080 if (real_idx < get_num_labels()) 00081 { 00082 m_labels[real_idx] = label; 00083 return true; 00084 } 00085 else 00086 return false; 00087 } 00088 00089 int32_t CMulticlassMultipleOutputLabels::get_num_labels() 00090 { 00091 return m_n_labels; 00092 } 00093 00094 void CMulticlassMultipleOutputLabels::init() 00095 { 00096 m_labels = NULL; 00097 m_n_labels = 0; 00098 }