00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h> 00012 #include <shogun/labels/BinaryLabels.h> 00013 #include <shogun/labels/MulticlassLabels.h> 00014 00015 using namespace shogun; 00016 00017 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy() 00018 :CMulticlassStrategy(), m_num_machines(0) 00019 { 00020 } 00021 00022 void CMulticlassOneVsOneStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels) 00023 { 00024 CMulticlassStrategy::train_start(orig_labels, train_labels); 00025 m_num_machines=m_num_classes*(m_num_classes-1)/2; 00026 00027 m_train_pair_idx_1 = 0; 00028 m_train_pair_idx_2 = 1; 00029 } 00030 00031 bool CMulticlassOneVsOneStrategy::train_has_more() 00032 { 00033 return m_train_iter < m_num_machines; 00034 } 00035 00036 SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next() 00037 { 00038 CMulticlassStrategy::train_prepare_next(); 00039 00040 SGVector<int32_t> subset(m_orig_labels->get_num_labels()); 00041 int32_t tot=0; 00042 for (int32_t k=0; k < m_orig_labels->get_num_labels(); ++k) 00043 { 00044 if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_1) 00045 { 00046 ((CBinaryLabels*) m_train_labels)->set_label(k, +1.0); 00047 subset[tot]=k; 00048 tot++; 00049 } 00050 else if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_2) 00051 { 00052 ((CBinaryLabels*) m_train_labels)->set_label(k, -1.0); 00053 subset[tot]=k; 00054 tot++; 00055 } 00056 } 00057 00058 m_train_pair_idx_2++; 00059 if (m_train_pair_idx_2 >= m_num_classes) 00060 { 00061 m_train_pair_idx_1++; 00062 m_train_pair_idx_2=m_train_pair_idx_1+1; 00063 } 00064 00065 subset.resize_vector(tot); 00066 return subset; 00067 } 00068 00069 int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs) 00070 { 00071 int32_t s=0; 00072 SGVector<int32_t> votes(m_num_classes); 00073 SGVector<int32_t> dec_vals(m_num_classes); 00074 votes.zero(); 00075 dec_vals.zero(); 00076 00077 for (int32_t i=0; i<m_num_classes; i++) 00078 { 00079 for (int32_t j=i+1; j<m_num_classes; j++) 00080 { 00081 if (outputs[s]>0) 00082 { 00083 votes[i]++; 00084 dec_vals[i] += CMath::abs(outputs[s]); 00085 } 00086 else 00087 { 00088 votes[j]++; 00089 dec_vals[j] += CMath::abs(outputs[s]); 00090 } 00091 s++; 00092 } 00093 } 00094 00095 int32_t i_max=0; 00096 int32_t vote_max=-1; 00097 float64_t dec_val_max=-1; 00098 00099 for (int32_t i=0; i < m_num_classes; ++i) 00100 { 00101 if (votes[i] > vote_max) 00102 { 00103 i_max = i; 00104 vote_max = votes[i]; 00105 dec_val_max = dec_vals[i]; 00106 } 00107 else if (votes[i] == vote_max) 00108 { 00109 if (dec_vals[i] > dec_val_max) 00110 { 00111 i_max = i; 00112 dec_val_max = dec_vals[i]; 00113 } 00114 } 00115 } 00116 00117 return i_max; 00118 }