Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/multiclass/MulticlassTreeGuidedLogisticRegression.h>
00012 #ifdef HAVE_EIGEN3
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/labels/MulticlassLabels.h>
00016 #include <shogun/lib/slep/slep_mc_tree_lr.h>
00017
00018 using namespace shogun;
00019
00020 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression() :
00021 CLinearMulticlassMachine()
00022 {
00023 init_defaults();
00024 }
00025
00026 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs, CIndexBlockTree* tree) :
00027 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs)
00028 {
00029 init_defaults();
00030 set_z(z);
00031 set_index_tree(tree);
00032 }
00033
00034 void CMulticlassTreeGuidedLogisticRegression::init_defaults()
00035 {
00036 m_index_tree = NULL;
00037 set_z(0.1);
00038 set_epsilon(1e-2);
00039 set_max_iter(10000);
00040 }
00041
00042 void CMulticlassTreeGuidedLogisticRegression::register_parameters()
00043 {
00044 SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE);
00045 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
00046 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
00047 }
00048
00049 CMulticlassTreeGuidedLogisticRegression::~CMulticlassTreeGuidedLogisticRegression()
00050 {
00051 SG_UNREF(m_index_tree);
00052 }
00053
00054 bool CMulticlassTreeGuidedLogisticRegression::train_machine(CFeatures* data)
00055 {
00056 if (data)
00057 set_features((CDotFeatures*)data);
00058
00059 ASSERT(m_features);
00060 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS);
00061 ASSERT(m_multiclass_strategy);
00062 ASSERT(m_index_tree);
00063
00064 int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes();
00065 int32_t n_feats = m_features->get_dim_feature_space();
00066
00067 slep_options options = slep_options::default_options();
00068 if (m_machines->get_num_elements()!=0)
00069 {
00070 SGMatrix<float64_t> all_w_old(n_feats, n_classes);
00071 SGVector<float64_t> all_c_old(n_classes);
00072 for (int32_t i=0; i<n_classes; i++)
00073 {
00074 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
00075 SGVector<float64_t> w = machine->get_w();
00076 for (int32_t j=0; j<n_feats; j++)
00077 all_w_old(j,i) = w[j];
00078 all_c_old[i] = machine->get_bias();
00079 SG_UNREF(machine);
00080 }
00081 options.last_result = new slep_result_t(all_w_old,all_c_old);
00082 m_machines->reset_array();
00083 }
00084 if (m_index_tree->is_general())
00085 {
00086 SGVector<float64_t> G = m_index_tree->get_SLEP_G();
00087 options.G = G.vector;
00088 }
00089 SGVector<float64_t> ind_t = m_index_tree->get_SLEP_ind_t();
00090 options.ind_t = ind_t.vector;
00091 options.n_nodes = ind_t.size()/3;
00092 options.tolerance = m_epsilon;
00093 options.max_iter = m_max_iter;
00094 slep_result_t result = slep_mc_tree_lr(m_features,(CMulticlassLabels*)m_labels,m_z,options);
00095
00096 SGMatrix<float64_t> all_w = result.w;
00097 SGVector<float64_t> all_c = result.c;
00098 for (int32_t i=0; i<n_classes; i++)
00099 {
00100 SGVector<float64_t> w(n_feats);
00101 for (int32_t j=0; j<n_feats; j++)
00102 w[j] = all_w(j,i);
00103 float64_t c = all_c[i];
00104 CLinearMachine* machine = new CLinearMachine();
00105 machine->set_w(w);
00106 machine->set_bias(c);
00107 m_machines->push_back(machine);
00108 }
00109 return true;
00110 }
00111 #endif