Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/config.h>
00012 #include <shogun/multiclass/MulticlassLibLinear.h>
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/lib/v_array.h>
00016 #include <shogun/lib/Signal.h>
00017 #include <shogun/labels/MulticlassLabels.h>
00018
00019 using namespace shogun;
00020
00021 CMulticlassLibLinear::CMulticlassLibLinear() :
00022 CLinearMulticlassMachine()
00023 {
00024 init_defaults();
00025 }
00026
00027 CMulticlassLibLinear::CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs) :
00028 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs)
00029 {
00030 init_defaults();
00031 set_C(C);
00032 }
00033
00034 void CMulticlassLibLinear::init_defaults()
00035 {
00036 set_C(1.0);
00037 set_epsilon(1e-2);
00038 set_max_iter(10000);
00039 set_use_bias(false);
00040 set_save_train_state(false);
00041 m_train_state = NULL;
00042 }
00043
00044 void CMulticlassLibLinear::register_parameters()
00045 {
00046 SG_ADD(&m_C, "m_C", "regularization constant",MS_AVAILABLE);
00047 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
00048 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
00049 SG_ADD(&m_use_bias, "m_use_bias", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00050 SG_ADD(&m_save_train_state, "m_save_train_state", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00051 }
00052
00053 CMulticlassLibLinear::~CMulticlassLibLinear()
00054 {
00055 reset_train_state();
00056 }
00057
00058 SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const
00059 {
00060 if (!m_train_state)
00061 SG_ERROR("Please enable save_train_state option and train machine.\n");
00062
00063 ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS);
00064
00065 int32_t num_vectors = m_features->get_num_vectors();
00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067
00068 v_array<int32_t> nz_idxs;
00069 nz_idxs.reserve(num_vectors);
00070
00071 for (int32_t i=0; i<num_vectors; i++)
00072 {
00073 for (int32_t y=0; y<num_classes; y++)
00074 {
00075 if (CMath::abs(m_train_state->alpha[i*num_classes+y])>1e-6)
00076 {
00077 nz_idxs.push(i);
00078 break;
00079 }
00080 }
00081 }
00082 int32_t num_nz = nz_idxs.index();
00083 nz_idxs.reserve(num_nz);
00084 return SGVector<int32_t>(nz_idxs.begin,num_nz);
00085 }
00086
00087 SGMatrix<float64_t> CMulticlassLibLinear::obtain_regularizer_matrix() const
00088 {
00089 return SGMatrix<float64_t>();
00090 }
00091
00092 bool CMulticlassLibLinear::train_machine(CFeatures* data)
00093 {
00094 if (data)
00095 set_features((CDotFeatures*)data);
00096
00097 ASSERT(m_features);
00098 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS);
00099 ASSERT(m_multiclass_strategy);
00100
00101 int32_t num_vectors = m_features->get_num_vectors();
00102 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00103 int32_t bias_n = m_use_bias ? 1 : 0;
00104
00105 problem mc_problem;
00106 mc_problem.l = num_vectors;
00107 mc_problem.n = m_features->get_dim_feature_space() + bias_n;
00108 mc_problem.y = SG_MALLOC(float64_t, mc_problem.l);
00109 for (int32_t i=0; i<num_vectors; i++)
00110 mc_problem.y[i] = ((CMulticlassLabels*) m_labels)->get_int_label(i);
00111
00112 mc_problem.x = m_features;
00113 mc_problem.use_bias = m_use_bias;
00114
00115 SGMatrix<float64_t> w0 = obtain_regularizer_matrix();
00116
00117 if (!m_train_state)
00118 m_train_state = new mcsvm_state();
00119
00120 float64_t* C = SG_MALLOC(float64_t, num_vectors);
00121 for (int32_t i=0; i<num_vectors; i++)
00122 C[i] = m_C;
00123
00124 CSignal::clear_cancel();
00125
00126 Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,m_epsilon,
00127 m_max_iter,m_max_train_time,m_train_state);
00128 solver.solve();
00129
00130 m_machines->reset_array();
00131 for (int32_t i=0; i<num_classes; i++)
00132 {
00133 CLinearMachine* machine = new CLinearMachine();
00134 SGVector<float64_t> cw(mc_problem.n-bias_n);
00135
00136 for (int32_t j=0; j<mc_problem.n-bias_n; j++)
00137 cw[j] = m_train_state->w[j*num_classes+i];
00138
00139 machine->set_w(cw);
00140
00141 if (m_use_bias)
00142 machine->set_bias(m_train_state->w[(mc_problem.n-bias_n)*num_classes+i]);
00143
00144 m_machines->push_back(machine);
00145 }
00146
00147 if (!m_save_train_state)
00148 reset_train_state();
00149
00150 SG_FREE(C);
00151 SG_FREE(mc_problem.y);
00152
00153 return true;
00154 }