Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/MulticlassLogisticRegression.h>
00012 #ifdef HAVE_EIGEN3
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/io/SGIO.h>
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/labels/MulticlassLabels.h>
00017 #include <shogun/lib/slep/slep_mc_plain_lr.h>
00018
00019 using namespace shogun;
00020
00021 CMulticlassLogisticRegression::CMulticlassLogisticRegression() :
00022 CLinearMulticlassMachine()
00023 {
00024 init_defaults();
00025 }
00026
00027 CMulticlassLogisticRegression::CMulticlassLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs) :
00028 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs)
00029 {
00030 init_defaults();
00031 set_z(z);
00032 }
00033
00034 void CMulticlassLogisticRegression::init_defaults()
00035 {
00036 set_z(0.1);
00037 set_epsilon(1e-2);
00038 set_max_iter(10000);
00039 }
00040
00041 void CMulticlassLogisticRegression::register_parameters()
00042 {
00043 SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE);
00044 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
00045 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
00046 }
00047
00048 CMulticlassLogisticRegression::~CMulticlassLogisticRegression()
00049 {
00050 }
00051
00052 bool CMulticlassLogisticRegression::train_machine(CFeatures* data)
00053 {
00054 SG_UNSTABLE("MulticlassLogisticRegression","\n");
00055 if (data)
00056 set_features((CDotFeatures*)data);
00057
00058 ASSERT(m_features);
00059 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS);
00060 ASSERT(m_multiclass_strategy);
00061
00062 int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes();
00063 int32_t n_feats = m_features->get_dim_feature_space();
00064
00065 slep_options options = slep_options::default_options();
00066 if (m_machines->get_num_elements()!=0)
00067 {
00068 SGMatrix<float64_t> all_w_old(n_feats, n_classes);
00069 SGVector<float64_t> all_c_old(n_classes);
00070 for (int32_t i=0; i<n_classes; i++)
00071 {
00072 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
00073 SGVector<float64_t> w = machine->get_w();
00074 for (int32_t j=0; j<n_feats; j++)
00075 all_w_old(j,i) = w[j];
00076 all_c_old[i] = machine->get_bias();
00077 SG_UNREF(machine);
00078 }
00079 options.last_result = new slep_result_t(all_w_old,all_c_old);
00080 m_machines->reset_array();
00081 }
00082 options.tolerance = m_epsilon;
00083 options.max_iter = m_max_iter;
00084 slep_result_t result = slep_mc_plain_lr(m_features,(CMulticlassLabels*)m_labels,m_z,options);
00085
00086 SGMatrix<float64_t> all_w = result.w;
00087 SGVector<float64_t> all_c = result.c;
00088 for (int32_t i=0; i<n_classes; i++)
00089 {
00090 SGVector<float64_t> w(n_feats);
00091 for (int32_t j=0; j<n_feats; j++)
00092 w[j] = all_w(j,i);
00093 float64_t c = all_c[i];
00094 CLinearMachine* machine = new CLinearMachine();
00095 machine->set_w(w);
00096 machine->set_bias(c);
00097 m_machines->push_back(machine);
00098 }
00099 return true;
00100 }
00101 #endif