SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
MulticlassTreeGuidedLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Sergey Lisitsyn
8  * Copyright (C) 2012 Sergey Lisitsyn
9  */
10 
11 
13 #ifdef USE_GPL_SHOGUN
18 
19 using namespace shogun;
20 
// Default constructor: applies the defaults set by init_defaults().
// NOTE(review): the base-class initializer that should follow the ':' (listing
// line 22) was dropped from this documentation extract; restore it from the real
// source -- as shown, this definition is incomplete and will not compile.
21 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression() :
23 {
24  init_defaults();
25 }
26 
// Constructor taking the regularization constant z, features, labels and the
// feature index tree. It first applies init_defaults(), then overrides z and the
// index tree via set_z() and set_index_tree().
// NOTE(review): feats and labs are not used in the visible body; presumably they
// are forwarded to the base class in the initializer (listing line 28), which is
// missing from this extract -- confirm against the real source.
27 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs, CIndexBlockTree* tree) :
29 {
30  init_defaults();
31  set_z(z);
32  set_index_tree(tree);
33 }
34 
35 void CMulticlassTreeGuidedLogisticRegression::init_defaults()
36 {
37  m_index_tree = NULL;
38  set_z(0.1);
39  set_epsilon(1e-2);
40  set_max_iter(10000);
41 }
42 
43 void CMulticlassTreeGuidedLogisticRegression::register_parameters()
44 {
45  SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE);
46  SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
47  SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
48 }
49 
50 CMulticlassTreeGuidedLogisticRegression::~CMulticlassTreeGuidedLogisticRegression()
51 {
52  SG_UNREF(m_index_tree);
53 }
54 
55 bool CMulticlassTreeGuidedLogisticRegression::train_machine(CFeatures* data)
56 {
57  if (data)
58  set_features((CDotFeatures*)data);
59 
60  ASSERT(m_features)
61  ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS)
62  ASSERT(m_multiclass_strategy)
63  ASSERT(m_index_tree)
64 
65  int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes();
66  int32_t n_feats = m_features->get_dim_feature_space();
67 
68  slep_options options = slep_options::default_options();
69  if (m_machines->get_num_elements()!=0)
70  {
71  SGMatrix<float64_t> all_w_old(n_feats, n_classes);
72  SGVector<float64_t> all_c_old(n_classes);
73  for (int32_t i=0; i<n_classes; i++)
74  {
75  CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
76  SGVector<float64_t> w = machine->get_w();
77  for (int32_t j=0; j<n_feats; j++)
78  all_w_old(j,i) = w[j];
79  all_c_old[i] = machine->get_bias();
80  SG_UNREF(machine);
81  }
82  options.last_result = new slep_result_t(all_w_old,all_c_old);
83  m_machines->reset_array();
84  }
85  if (m_index_tree->is_general())
86  {
87  SGVector<float64_t> G = m_index_tree->get_SLEP_G();
88  options.G = G.vector;
89  }
90  SGVector<float64_t> ind_t = m_index_tree->get_SLEP_ind_t();
91  options.ind_t = ind_t.vector;
92  options.n_nodes = ind_t.size()/3;
93  options.tolerance = m_epsilon;
94  options.max_iter = m_max_iter;
95  slep_result_t result = slep_mc_tree_lr(m_features,(CMulticlassLabels*)m_labels,m_z,options);
96 
97  SGMatrix<float64_t> all_w = result.w;
98  SGVector<float64_t> all_c = result.c;
99  for (int32_t i=0; i<n_classes; i++)
100  {
101  SGVector<float64_t> w(n_feats);
102  for (int32_t j=0; j<n_feats; j++)
103  w[j] = all_w(j,i);
104  float64_t c = all_c[i];
105  CLinearMachine* machine = new CLinearMachine();
106  machine->set_w(w);
107  machine->set_bias(c);
108  m_machines->push_back(machine);
109  }
110  return true;
111 }
112 #endif //USE_GPL_SHOGUN
virtual void set_w(const SGVector< float64_t > src_w)
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
class IndexBlockTree used to represent tree guided feature relation.
multi-class labels 0,1,...
Definition: LabelTypes.h:20
Features that support dot products among other operations.
Definition: DotFeatures.h:44
Multiclass Labels for multi-class classification.
int32_t size() const
Definition: SGVector.h:113
generic linear multiclass machine
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
Definition: LinearMachine.h:63
virtual SGVector< float64_t > get_w() const
#define SG_UNREF(x)
Definition: SGObject.h:52
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual float64_t get_bias()
virtual void set_bias(float64_t b)
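void set_epsilon(float *begin, float max) — note: Doxygen mis-resolved this cross-reference. This two-argument JLCoverTree helper is unrelated to the one-argument set_epsilon(1e-2) call in init_defaults().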
Definition: JLCoverTree.h:513
multiclass one vs rest strategy used to train generic multiclass machines for K-class problems with b...
#define SG_ADD(...)
Definition: SGObject.h:81

SHOGUN Machine Learning Toolbox - Documentation