ECOCStrategy.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/multiclass/ecoc/ECOCStrategy.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CECOCStrategy::CECOCStrategy()
00018 {
00019     init();
00020 }
00021 
00022 CECOCStrategy::CECOCStrategy(CECOCEncoder *encoder, CECOCDecoder *decoder)
00023     :m_encoder(encoder), m_decoder(decoder)
00024 {
00025     init();
00026 }
00027 
00028 void CECOCStrategy::init()
00029 {
00030     SG_REF(m_encoder);
00031     SG_REF(m_decoder);
00032 
00033     SG_ADD((CSGObject **)&m_encoder, "encoder", "ECOC Encoder", MS_NOT_AVAILABLE);
00034     SG_ADD((CSGObject **)&m_decoder, "decoder", "ECOC Decoder", MS_NOT_AVAILABLE);
00035 }
00036 
00037 CECOCStrategy::~CECOCStrategy()
00038 {
00039     SG_UNREF(m_encoder);
00040     SG_UNREF(m_decoder);
00041 }
00042 
00043 void CECOCStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
00044 {
00045     CMulticlassStrategy::train_start(orig_labels, train_labels);
00046 
00047     m_codebook = m_encoder->create_codebook(m_num_classes);
00048 }
00049 
00050 bool CECOCStrategy::train_has_more()
00051 {
00052     return m_train_iter < m_codebook.num_rows;
00053 }
00054 
00055 SGVector<int32_t> CECOCStrategy::train_prepare_next()
00056 {
00057     SGVector<int32_t> subset(m_orig_labels->get_num_labels(), false);
00058     int32_t tot=0;
00059     for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i)
00060     {
00061         int32_t label = ((CMulticlassLabels*) m_orig_labels)->get_int_label(i);
00062         switch (m_codebook(m_train_iter, label))
00063         {
00064         case -1:
00065             ((CBinaryLabels*) m_train_labels)->set_label(i, -1);
00066             subset[tot++]=i;
00067             break;
00068         case 1:
00069             ((CBinaryLabels*) m_train_labels)->set_label(i, 1);
00070             subset[tot++]=i;
00071             break;
00072         default:
00073             // 0 means ignore
00074             break;
00075         }
00076     }
00077 
00078     CMulticlassStrategy::train_prepare_next();
00079     return SGVector<int32_t>(subset.vector, tot, true);
00080 }
00081 
00082 int32_t CECOCStrategy::decide_label(SGVector<float64_t> outputs)
00083 {
00084     return m_decoder->decide_label(outputs, m_codebook);
00085 }
00086 
00087 int32_t CECOCStrategy::get_num_machines()
00088 {
00089     return m_codebook.num_cols;
00090 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation