MulticlassOneVsOneStrategy.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
00018     :CMulticlassStrategy(), m_num_machines(0)
00019 {
00020 }
00021 
00022 void CMulticlassOneVsOneStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
00023 {
00024     CMulticlassStrategy::train_start(orig_labels, train_labels);
00025     m_num_machines=m_num_classes*(m_num_classes-1)/2;
00026 
00027     m_train_pair_idx_1 = 0;
00028     m_train_pair_idx_2 = 1;
00029 }
00030 
00031 bool CMulticlassOneVsOneStrategy::train_has_more()
00032 {
00033     return m_train_iter < m_num_machines;
00034 }
00035 
00036 SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
00037 {
00038     CMulticlassStrategy::train_prepare_next();
00039 
00040     SGVector<int32_t> subset(m_orig_labels->get_num_labels());
00041     int32_t tot=0;
00042     for (int32_t k=0; k < m_orig_labels->get_num_labels(); ++k)
00043     {
00044         if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_1)
00045         {
00046             ((CBinaryLabels*) m_train_labels)->set_label(k, +1.0);
00047             subset[tot]=k;
00048             tot++;
00049         }
00050         else if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_2)
00051         {
00052             ((CBinaryLabels*) m_train_labels)->set_label(k, -1.0);
00053             subset[tot]=k;
00054             tot++;
00055         }
00056     }
00057 
00058     m_train_pair_idx_2++;
00059     if (m_train_pair_idx_2 >= m_num_classes)
00060     {
00061         m_train_pair_idx_1++;
00062         m_train_pair_idx_2=m_train_pair_idx_1+1;
00063     }
00064 
00065     subset.resize_vector(tot);
00066     return subset;
00067 }
00068 
00069 int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs)
00070 {
00071     int32_t s=0;
00072     SGVector<int32_t> votes(m_num_classes);
00073     SGVector<int32_t> dec_vals(m_num_classes);
00074     votes.zero();
00075     dec_vals.zero();
00076 
00077     for (int32_t i=0; i<m_num_classes; i++)
00078     {
00079         for (int32_t j=i+1; j<m_num_classes; j++)
00080         {
00081             if (outputs[s]>0)
00082             {
00083                 votes[i]++;
00084                 dec_vals[i] += CMath::abs(outputs[s]);
00085             }
00086             else
00087             {
00088                 votes[j]++;
00089                 dec_vals[j] += CMath::abs(outputs[s]);
00090             }
00091             s++;
00092         }
00093     }
00094 
00095     int32_t i_max=0;
00096     int32_t vote_max=-1;
00097     float64_t dec_val_max=-1;
00098 
00099     for (int32_t i=0; i < m_num_classes; ++i)
00100     {
00101         if (votes[i] > vote_max)
00102         {
00103             i_max = i;
00104             vote_max = votes[i];
00105             dec_val_max = dec_vals[i];
00106         }
00107         else if (votes[i] == vote_max)
00108         {
00109             if (dec_vals[i] > dec_val_max)
00110             {
00111                 i_max = i;
00112                 dec_val_max = dec_vals[i];
00113             }
00114         }
00115     }
00116 
00117     return i_max;
00118 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation