MulticlassLibLinear.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Sergey Lisitsyn
00008  * Copyright (C) 2012 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 #include <shogun/multiclass/MulticlassLibLinear.h>
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/lib/v_array.h>
00016 #include <shogun/lib/Signal.h>
00017 #include <shogun/labels/MulticlassLabels.h>
00018 
00019 using namespace shogun;
00020 
00021 CMulticlassLibLinear::CMulticlassLibLinear() :
00022     CLinearMulticlassMachine()
00023 {
00024     init_defaults();
00025 }
00026 
00027 CMulticlassLibLinear::CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs) :
00028     CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs)
00029 {
00030     init_defaults();
00031     set_C(C);
00032 }
00033 
00034 void CMulticlassLibLinear::init_defaults()
00035 {
00036     set_C(1.0);
00037     set_epsilon(1e-2);
00038     set_max_iter(10000);
00039     set_use_bias(false);
00040     set_save_train_state(false);
00041     m_train_state = NULL;
00042 }
00043 
00044 void CMulticlassLibLinear::register_parameters()
00045 {
00046     SG_ADD(&m_C, "m_C", "regularization constant",MS_AVAILABLE);
00047     SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
00048     SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
00049     SG_ADD(&m_use_bias, "m_use_bias", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00050     SG_ADD(&m_save_train_state, "m_save_train_state", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00051 }
00052 
00053 CMulticlassLibLinear::~CMulticlassLibLinear()
00054 {
00055     reset_train_state();
00056 }
00057 
00058 SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const
00059 {
00060     if (!m_train_state)
00061         SG_ERROR("Please enable save_train_state option and train machine.\n");
00062 
00063     ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS);
00064 
00065     int32_t num_vectors = m_features->get_num_vectors();
00066     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067 
00068     v_array<int32_t> nz_idxs;
00069     nz_idxs.reserve(num_vectors);
00070 
00071     for (int32_t i=0; i<num_vectors; i++)
00072     {
00073         for (int32_t y=0; y<num_classes; y++)
00074         {
00075             if (CMath::abs(m_train_state->alpha[i*num_classes+y])>1e-6)
00076             {
00077                 nz_idxs.push(i);
00078                 break;
00079             }
00080         }
00081     }
00082     int32_t num_nz = nz_idxs.index();
00083     nz_idxs.reserve(num_nz);
00084     return SGVector<int32_t>(nz_idxs.begin,num_nz);
00085 }
00086 
00087 SGMatrix<float64_t> CMulticlassLibLinear::obtain_regularizer_matrix() const
00088 {
00089     return SGMatrix<float64_t>();
00090 }
00091 
00092 bool CMulticlassLibLinear::train_machine(CFeatures* data)
00093 {
00094     if (data)
00095         set_features((CDotFeatures*)data);
00096 
00097     ASSERT(m_features);
00098     ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS);
00099     ASSERT(m_multiclass_strategy);
00100 
00101     int32_t num_vectors = m_features->get_num_vectors();
00102     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00103     int32_t bias_n = m_use_bias ? 1 : 0;
00104 
00105     problem mc_problem;
00106     mc_problem.l = num_vectors;
00107     mc_problem.n = m_features->get_dim_feature_space() + bias_n;
00108     mc_problem.y = SG_MALLOC(float64_t, mc_problem.l);
00109     for (int32_t i=0; i<num_vectors; i++)
00110         mc_problem.y[i] = ((CMulticlassLabels*) m_labels)->get_int_label(i);
00111 
00112     mc_problem.x = m_features;
00113     mc_problem.use_bias = m_use_bias;
00114 
00115     SGMatrix<float64_t> w0 = obtain_regularizer_matrix();
00116 
00117     if (!m_train_state)
00118         m_train_state = new mcsvm_state();
00119 
00120     float64_t* C = SG_MALLOC(float64_t, num_vectors);
00121     for (int32_t i=0; i<num_vectors; i++)
00122         C[i] = m_C;
00123 
00124     CSignal::clear_cancel();
00125 
00126     Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,m_epsilon,
00127                            m_max_iter,m_max_train_time,m_train_state);
00128     solver.solve();
00129 
00130     m_machines->reset_array();
00131     for (int32_t i=0; i<num_classes; i++)
00132     {
00133         CLinearMachine* machine = new CLinearMachine();
00134         SGVector<float64_t> cw(mc_problem.n-bias_n);
00135 
00136         for (int32_t j=0; j<mc_problem.n-bias_n; j++)
00137             cw[j] = m_train_state->w[j*num_classes+i];
00138 
00139         machine->set_w(cw);
00140 
00141         if (m_use_bias)
00142             machine->set_bias(m_train_state->w[(mc_problem.n-bias_n)*num_classes+i]);
00143 
00144         m_machines->push_back(machine);
00145     }
00146 
00147     if (!m_save_train_state)
00148         reset_train_state();
00149 
00150     SG_FREE(C);
00151     SG_FREE(mc_problem.y);
00152 
00153     return true;
00154 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation