MulticlassMachine.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2011 Soeren Sonnenburg
00008  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
00009  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
00010  */
00011 
00012 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00013 #include <shogun/machine/LinearMachine.h>
00014 #include <shogun/machine/KernelMachine.h>
00015 #include <shogun/machine/MulticlassMachine.h>
00016 #include <shogun/base/Parameter.h>
00017 #include <shogun/labels/MulticlassLabels.h>
00018 #include <shogun/labels/RegressionLabels.h>
00019 
00020 using namespace shogun;
00021 
00022 CMulticlassMachine::CMulticlassMachine()
00023 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
00024     m_machine(NULL)
00025 {
00026     SG_REF(m_multiclass_strategy);
00027     register_parameters();
00028 }
00029 
00030 CMulticlassMachine::CMulticlassMachine(
00031         CMulticlassStrategy *strategy,
00032         CMachine* machine, CLabels* labs)
00033 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
00034 {
00035     SG_REF(strategy);
00036     set_labels(labs);
00037     SG_REF(machine);
00038     m_machine = machine;
00039     register_parameters();
00040 
00041     if (labs)
00042         init_strategy();
00043 }
00044 
00045 CMulticlassMachine::~CMulticlassMachine()
00046 {
00047     SG_UNREF(m_multiclass_strategy);
00048     SG_UNREF(m_machine);
00049 }
00050 
00051 void CMulticlassMachine::set_labels(CLabels* lab)
00052 {
00053     CMachine::set_labels(lab);
00054     if (lab)
00055         init_strategy();
00056 }
00057 
00058 void CMulticlassMachine::register_parameters()
00059 {
00060     SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
00061     SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
00062 }
00063 
00064 void CMulticlassMachine::init_strategy()
00065 {
00066     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067     m_multiclass_strategy->set_num_classes(num_classes);
00068 }
00069 
00070 CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i)
00071 {
00072     CMachine *machine = (CMachine*)m_machines->get_element(i);
00073     ASSERT(machine);
00074     CBinaryLabels* output = machine->apply_binary();
00075     SG_UNREF(machine);
00076     return output;
00077 }
00078 
00079 float64_t CMulticlassMachine::get_submachine_output(int32_t i, int32_t num)
00080 {
00081     CMachine *machine = get_machine(i);
00082     float64_t output = 0.0;
00083     // dirty hack
00084     if (dynamic_cast<CLinearMachine*>(machine))
00085         output = ((CLinearMachine*)machine)->apply_one(num);
00086     if (dynamic_cast<CKernelMachine*>(machine))
00087         output = ((CKernelMachine*)machine)->apply_one(num);
00088     SG_UNREF(machine);
00089     return output;
00090 }
00091 
00092 CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
00093 {
00094     SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
00095             get_name(), data ? data->get_name() : "NULL", data);
00096 
00097     CMulticlassLabels* return_labels=NULL;
00098 
00099     if (data)
00100         init_machines_for_apply(data);
00101     else
00102         init_machines_for_apply(NULL);
00103 
00104     if (is_ready())
00105     {
00106         /* num vectors depends on whether data is provided */
00107         int32_t num_vectors=data ? data->get_num_vectors() :
00108                 get_num_rhs_vectors();
00109 
00110         int32_t num_machines=m_machines->get_num_elements();
00111         if (num_machines <= 0)
00112             SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
00113 
00114         CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
00115         CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00116 
00117         for (int32_t i=0; i < num_machines; ++i)
00118             outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00119 
00120         SGVector<float64_t> output_for_i(num_machines);
00121         for (int32_t i=0; i<num_vectors; i++)
00122         {
00123             for (int32_t j=0; j<num_machines; j++)
00124                 output_for_i[j] = outputs[j]->get_value(i);
00125 
00126             result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
00127             result->set_multiclass_confidences(i, output_for_i.clone());
00128         }
00129 
00130         for (int32_t i=0; i < num_machines; ++i)
00131             SG_UNREF(outputs[i]);
00132 
00133         SG_FREE(outputs);
00134 
00135         return_labels=result;
00136     }
00137     else
00138         SG_ERROR("Not ready");
00139 
00140 
00141     SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
00142                 get_name(), data ? data->get_name() : "NULL", data);
00143     return return_labels;
00144 }
00145 
00146 CMulticlassMultipleOutputLabels* CMulticlassMachine::apply_multiclass_multiple_output(CFeatures* data, int32_t n_outputs)
00147 {
00148     CMulticlassMultipleOutputLabels* return_labels=NULL;
00149 
00150     if (data)
00151         init_machines_for_apply(data);
00152     else
00153         init_machines_for_apply(NULL);
00154 
00155     if (is_ready())
00156     {
00157         /* num vectors depends on whether data is provided */
00158         int32_t num_vectors=data ? data->get_num_vectors() :
00159                 get_num_rhs_vectors();
00160 
00161         int32_t num_machines=m_machines->get_num_elements();
00162         if (num_machines <= 0)
00163             SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
00164         REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available");
00165 
00166         CMulticlassMultipleOutputLabels* result=new CMulticlassMultipleOutputLabels(num_vectors);
00167         CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00168 
00169         for (int32_t i=0; i < num_machines; ++i)
00170             outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00171 
00172         SGVector<float64_t> output_for_i(num_machines);
00173         for (int32_t i=0; i<num_vectors; i++)
00174         {
00175             for (int32_t j=0; j<num_machines; j++)
00176                 output_for_i[j] = outputs[j]->get_value(i);
00177 
00178             result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
00179         }
00180 
00181         for (int32_t i=0; i < num_machines; ++i)
00182             SG_UNREF(outputs[i]);
00183 
00184         SG_FREE(outputs);
00185 
00186         return_labels=result;
00187     }
00188     else
00189         SG_ERROR("Not ready");
00190 
00191     return return_labels;
00192 }
00193 
00194 bool CMulticlassMachine::train_machine(CFeatures* data)
00195 {
00196     ASSERT(m_multiclass_strategy);
00197 
00198     if ( !data && !is_ready() )
00199         SG_ERROR("Please provide training data.\n");
00200     else
00201         init_machine_for_train(data);
00202 
00203     m_machines->reset_array();
00204     CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
00205     SG_REF(train_labels);
00206     m_machine->set_labels(train_labels);
00207 
00208     m_multiclass_strategy->train_start(CMulticlassLabels::obtain_from_generic(m_labels), train_labels);
00209     while (m_multiclass_strategy->train_has_more())
00210     {
00211         SGVector<index_t> subset=m_multiclass_strategy->train_prepare_next();
00212         if (subset.vlen)
00213         {
00214             train_labels->add_subset(subset);
00215             add_machine_subset(subset);
00216         }
00217 
00218         m_machine->train();
00219         m_machines->push_back(get_machine_from_trained(m_machine));
00220 
00221         if (subset.vlen)
00222         {
00223             train_labels->remove_subset();
00224             remove_machine_subset();
00225         }
00226     }
00227 
00228     m_multiclass_strategy->train_stop();
00229     SG_UNREF(train_labels);
00230 
00231     return true;
00232 }
00233 
00234 float64_t CMulticlassMachine::apply_one(int32_t vec_idx)
00235 {
00236     init_machines_for_apply(NULL);
00237 
00238     ASSERT(m_machines->get_num_elements()>0);
00239     SGVector<float64_t> outputs(m_machines->get_num_elements());
00240 
00241     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00242         outputs[i] = get_submachine_output(i, vec_idx);
00243 
00244     float64_t result = m_multiclass_strategy->decide_label(outputs);
00245 
00246     return result;
00247 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation