SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
MultitaskLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
10 
12 #ifdef USE_GPL_SHOGUN
15 #include <vector>
16 
17 namespace shogun
18 {
19 
20 CMultitaskLogisticRegression::CMultitaskLogisticRegression() :
21  CMultitaskLinearMachine()
22 {
23  initialize_parameters();
24  register_parameters();
25 }
26 
27 CMultitaskLogisticRegression::CMultitaskLogisticRegression(
28  float64_t z, CDotFeatures* train_features,
29  CBinaryLabels* train_labels, CTaskRelation* task_relation) :
30  CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
31 {
32  initialize_parameters();
33  register_parameters();
34  set_z(z);
35 }
36 
37 CMultitaskLogisticRegression::~CMultitaskLogisticRegression()
38 {
39 }
40 
41 void CMultitaskLogisticRegression::register_parameters()
42 {
43  SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
44  SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
45  SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
46  SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
47  SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
48  SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
49 }
50 
51 void CMultitaskLogisticRegression::initialize_parameters()
52 {
53  set_z(0.0);
54  set_q(2.0);
55  set_termination(0);
56  set_regularization(0);
57  set_tolerance(1e-3);
58  set_max_iter(1000);
59 }
60 
61 bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
62 {
63  if (data && (CDotFeatures*)data)
64  set_features((CDotFeatures*)data);
65 
66  ASSERT(features)
67  ASSERT(m_labels)
68 
69  SGVector<float64_t> y(m_labels->get_num_labels());
70  for (int32_t i=0; i<y.vlen; i++)
71  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
72 
73  slep_options options = slep_options::default_options();
74  options.n_tasks = m_task_relation->get_num_tasks();
75  options.tasks_indices = m_task_relation->get_tasks_indices();
76  options.q = m_q;
77  options.regularization = m_regularization;
78  options.termination = m_termination;
79  options.tolerance = m_tolerance;
80  options.max_iter = m_max_iter;
81 
82  ETaskRelationType relation_type = m_task_relation->get_relation_type();
83  switch (relation_type)
84  {
85  case TASK_GROUP:
86  {
87  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
88  options.mode = MULTITASK_GROUP;
89  options.loss = LOGISTIC;
90  slep_result_t result = slep_solver(features, y.vector, m_z, options);
91  m_tasks_w = result.w;
92  m_tasks_c = result.c;
93  }
94  break;
95  case TASK_TREE:
96  {
97  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
98  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
99  options.ind_t = ind_t.vector;
100  options.n_nodes = ind_t.vlen / 3;
101  options.mode = MULTITASK_TREE;
102  options.loss = LOGISTIC;
103  slep_result_t result = slep_solver(features, y.vector, m_z, options);
104  m_tasks_w = result.w;
105  m_tasks_c = result.c;
106  }
107  break;
108  default:
109  SG_ERROR("Not supported task relation type\n")
110  }
111  SG_FREE(options.tasks_indices);
112 
113  return true;
114 }
115 
116 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
117 {
118  ASSERT(features)
119  ASSERT(m_labels)
120 
121  SGVector<float64_t> y(m_labels->get_num_labels());
122  for (int32_t i=0; i<y.vlen; i++)
123  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
124 
125  slep_options options = slep_options::default_options();
126  options.n_tasks = m_task_relation->get_num_tasks();
127  options.tasks_indices = tasks;
128  options.q = m_q;
129  options.regularization = m_regularization;
130  options.termination = m_termination;
131  options.tolerance = m_tolerance;
132  options.max_iter = m_max_iter;
133 
134  ETaskRelationType relation_type = m_task_relation->get_relation_type();
135  switch (relation_type)
136  {
137  case TASK_GROUP:
138  {
139  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
140  options.mode = MULTITASK_GROUP;
141  options.loss = LOGISTIC;
142  slep_result_t result = slep_solver(features, y.vector, m_z, options);
143  m_tasks_w = result.w;
144  m_tasks_c = result.c;
145  }
146  break;
147  case TASK_TREE:
148  {
149  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
150  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
151  options.ind_t = ind_t.vector;
152  options.n_nodes = ind_t.vlen / 3;
153  options.mode = MULTITASK_TREE;
154  options.loss = LOGISTIC;
155  slep_result_t result = slep_solver(features, y.vector, m_z, options);
156  m_tasks_w = result.w;
157  m_tasks_c = result.c;
158  }
159  break;
160  default:
161  SG_ERROR("Not supported task relation type\n")
162  }
163  return true;
164 }
165 
166 float64_t CMultitaskLogisticRegression::apply_one(int32_t i)
167 {
168  float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
169  //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
170  //return 2.0/(1.0+ep) - 1.0;
171  return dot + m_tasks_c[m_current_task];
172 }
173 
174 int32_t CMultitaskLogisticRegression::get_max_iter() const
175 {
176  return m_max_iter;
177 }
178 int32_t CMultitaskLogisticRegression::get_regularization() const
179 {
180  return m_regularization;
181 }
182 int32_t CMultitaskLogisticRegression::get_termination() const
183 {
184  return m_termination;
185 }
186 float64_t CMultitaskLogisticRegression::get_tolerance() const
187 {
188  return m_tolerance;
189 }
190 float64_t CMultitaskLogisticRegression::get_z() const
191 {
192  return m_z;
193 }
194 float64_t CMultitaskLogisticRegression::get_q() const
195 {
196  return m_q;
197 }
198 
199 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter)
200 {
201  ASSERT(max_iter>=0)
202  m_max_iter = max_iter;
203 }
204 void CMultitaskLogisticRegression::set_regularization(int32_t regularization)
205 {
206  ASSERT(regularization==0 || regularization==1)
207  m_regularization = regularization;
208 }
209 void CMultitaskLogisticRegression::set_termination(int32_t termination)
210 {
211  ASSERT(termination>=0 && termination<=4)
212  m_termination = termination;
213 }
214 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance)
215 {
216  ASSERT(tolerance>0.0)
217  m_tolerance = tolerance;
218 }
219 void CMultitaskLogisticRegression::set_z(float64_t z)
220 {
221  m_z = z;
222 }
223 void CMultitaskLogisticRegression::set_q(float64_t q)
224 {
225  m_q = q;
226 }
227 
228 }
229 
230 #endif //USE_GPL_SHOGUN
Vector::Scalar dot(Vector a, Vector b)
Definition: Redux.h:58
#define SG_ERROR(...)
Definition: SGIO.h:129
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_ADD(...)
Definition: SGObject.h:84

SHOGUN Machine Learning Toolbox - Documentation