Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/ecoc/ECOCStrategy.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014
00015 using namespace shogun;
00016
00017 CECOCStrategy::CECOCStrategy()
00018 {
00019 init();
00020 }
00021
00022 CECOCStrategy::CECOCStrategy(CECOCEncoder *encoder, CECOCDecoder *decoder)
00023 :m_encoder(encoder), m_decoder(decoder)
00024 {
00025 init();
00026 }
00027
00028 void CECOCStrategy::init()
00029 {
00030 SG_REF(m_encoder);
00031 SG_REF(m_decoder);
00032
00033 SG_ADD((CSGObject **)&m_encoder, "encoder", "ECOC Encoder", MS_NOT_AVAILABLE);
00034 SG_ADD((CSGObject **)&m_decoder, "decoder", "ECOC Decoder", MS_NOT_AVAILABLE);
00035 }
00036
00037 CECOCStrategy::~CECOCStrategy()
00038 {
00039 SG_UNREF(m_encoder);
00040 SG_UNREF(m_decoder);
00041 }
00042
00043 void CECOCStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
00044 {
00045 CMulticlassStrategy::train_start(orig_labels, train_labels);
00046
00047 m_codebook = m_encoder->create_codebook(m_num_classes);
00048 }
00049
00050 bool CECOCStrategy::train_has_more()
00051 {
00052 return m_train_iter < m_codebook.num_rows;
00053 }
00054
00055 SGVector<int32_t> CECOCStrategy::train_prepare_next()
00056 {
00057 SGVector<int32_t> subset(m_orig_labels->get_num_labels(), false);
00058 int32_t tot=0;
00059 for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i)
00060 {
00061 int32_t label = ((CMulticlassLabels*) m_orig_labels)->get_int_label(i);
00062 switch (m_codebook(m_train_iter, label))
00063 {
00064 case -1:
00065 ((CBinaryLabels*) m_train_labels)->set_label(i, -1);
00066 subset[tot++]=i;
00067 break;
00068 case 1:
00069 ((CBinaryLabels*) m_train_labels)->set_label(i, 1);
00070 subset[tot++]=i;
00071 break;
00072 default:
00073
00074 break;
00075 }
00076 }
00077
00078 CMulticlassStrategy::train_prepare_next();
00079 return SGVector<int32_t>(subset.vector, tot, true);
00080 }
00081
00082 int32_t CECOCStrategy::decide_label(SGVector<float64_t> outputs)
00083 {
00084 return m_decoder->decide_label(outputs, m_codebook);
00085 }
00086
00087 int32_t CECOCStrategy::get_num_machines()
00088 {
00089 return m_codebook.num_cols;
00090 }