SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MultitaskLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
13 
14 namespace shogun
15 {
16 
19 {
// NOTE(review): the signature line is not shown in this listing; presumably the
// default constructor. It applies the default parameter values
// (initialize_parameters) and then registers the members with SG_ADD
// (register_parameters).
20  initialize_parameters();
21  register_parameters();
22 }
23 
25  float64_t z, CDotFeatures* train_features,
26  CBinaryLabels* train_labels, CTaskRelation* task_relation) :
27  CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
28 {
// Constructor taking the training data and task relation, which are forwarded
// to CMultitaskLinearMachine; the binary labels are upcast to CLabels*.
// It installs the same defaults and registrations as the default constructor,
// then overrides the regularization coefficient with the caller's z.
// set_z performs no range check.
29  initialize_parameters();
30  register_parameters();
31  set_z(z);
32 }
33 
35 {
// NOTE(review): the signature line is missing from this listing; presumably the
// destructor. The body is empty, so nothing is released here.
36 }
37 
38 void CMultitaskLogisticRegression::register_parameters()
39 {
40  SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
41  SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
42  SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
43  SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
44  SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
45  SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
46 }
47 
48 void CMultitaskLogisticRegression::initialize_parameters()
49 {
50  set_z(0.0);
51  set_q(2.0);
52  set_termination(0);
54  set_tolerance(1e-3);
55  set_max_iter(1000);
56 }
57 
59 {
// NOTE(review): the signature line is not shown in this listing. From the body this
// is presumably train_machine(CFeatures* data).
//
// What it does:
// - Builds slep_options from the stored parameters.
// - Calls slep_solver with the LOGISTIC loss, in MULTITASK_GROUP or MULTITASK_TREE
//   mode depending on the task relation type.
// - Stores the returned weights and biases in m_tasks_w / m_tasks_c.
//
// NOTE(review): the C-style cast below can never yield null for a non-null
// pointer, so `(CDotFeatures*)data` is always true once `data` is non-null.
// The check therefore does not verify that `data` really is a CDotFeatures;
// consider a real type check (e.g. dynamic_cast) — TODO confirm.
60  if (data && (CDotFeatures*)data)
61  set_features((CDotFeatures*)data);
62 
// NOTE(review): listing lines 63-66 are missing here; presumably they declare `y`
// (sized from m_labels) and check preconditions.
65 
// Copy the binary labels into the dense target vector passed to the solver.
67  for (int32_t i=0; i<y.vlen; i++)
68  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
69 
// Solver options: stored parameters plus the task partition from the relation.
// tasks_indices comes from get_tasks_indices() and is released with SG_FREE
// after the switch.
70  slep_options options = slep_options::default_options();
71  options.n_tasks = m_task_relation->get_num_tasks();
72  options.tasks_indices = m_task_relation->get_tasks_indices();
73  options.q = m_q;
74  options.regularization = m_regularization;
75  options.termination = m_termination;
76  options.tolerance = m_tolerance;
77  options.max_iter = m_max_iter;
78 
79  ETaskRelationType relation_type = m_task_relation->get_relation_type();
80  switch (relation_type)
81  {
82  case TASK_GROUP:
83  {
84  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
85  options.mode = MULTITASK_GROUP;
86  options.loss = LOGISTIC;
87  slep_result_t result = slep_solver(features, y.vector, m_z, options);
88  m_tasks_w = result.w;
89  m_tasks_c = result.c;
90  }
91  break;
92  case TASK_TREE:
93  {
// Tree mode also needs the SLEP node index vector.
// n_nodes assumes three entries per node (vlen / 3).
// options.ind_t points into the local `ind_t`, which stays alive until
// slep_solver returns.
94  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
95  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
96  options.ind_t = ind_t.vector;
97  options.n_nodes = ind_t.vlen / 3;
98  options.mode = MULTITASK_TREE;
99  options.loss = LOGISTIC;
100  slep_result_t result = slep_solver(features, y.vector, m_z, options);
101  m_tasks_w = result.w;
102  m_tasks_c = result.c;
103  }
104  break;
105  default:
// NOTE(review): if SG_ERROR does not return (e.g. it throws), the SG_FREE below is
// skipped and options.tasks_indices leaks — confirm SG_ERROR's behavior.
106  SG_ERROR("Not supported task relation type\n")
107  }
108  SG_FREE(options.tasks_indices);
109 
110  return true;
111 }
112 
114 {
// NOTE(review): the signature line is not shown; presumably
// train_locked_implementation, taking a `tasks` argument.
// The body mirrors the other training path, with one difference: the task
// partition comes from the caller (`tasks`) rather than from
// m_task_relation->get_tasks_indices(). Because of that, it is not freed here.
// The solver switch is duplicated verbatim from the other path; consider
// extracting a shared helper.
117 
// NOTE(review): the declaration of `y` (and any precondition checks) falls on
// listing lines not shown here.
// Copy the binary labels into the dense target vector passed to the solver.
119  for (int32_t i=0; i<y.vlen; i++)
120  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
121 
122  slep_options options = slep_options::default_options();
123  options.n_tasks = m_task_relation->get_num_tasks();
124  options.tasks_indices = tasks;
125  options.q = m_q;
126  options.regularization = m_regularization;
127  options.termination = m_termination;
128  options.tolerance = m_tolerance;
129  options.max_iter = m_max_iter;
130 
131  ETaskRelationType relation_type = m_task_relation->get_relation_type();
132  switch (relation_type)
133  {
134  case TASK_GROUP:
135  {
136  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
137  options.mode = MULTITASK_GROUP;
138  options.loss = LOGISTIC;
139  slep_result_t result = slep_solver(features, y.vector, m_z, options);
140  m_tasks_w = result.w;
141  m_tasks_c = result.c;
142  }
143  break;
144  case TASK_TREE:
145  {
// Tree mode: n_nodes assumes three entries per node in ind_t (vlen / 3).
146  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
147  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
148  options.ind_t = ind_t.vector;
149  options.n_nodes = ind_t.vlen / 3;
150  options.mode = MULTITASK_TREE;
151  options.loss = LOGISTIC;
152  slep_result_t result = slep_solver(features, y.vector, m_z, options);
153  m_tasks_w = result.w;
154  m_tasks_c = result.c;
155  }
156  break;
157  default:
158  SG_ERROR("Not supported task relation type\n")
159  }
160  return true;
161 }
162 
164 {
// NOTE(review): the signature line and listing line 165 (which computes `dot`) are
// missing here. Presumably apply_one for the task selected by m_current_task.
// Returns the raw linear score s = dot + m_tasks_c[m_current_task].
// The commented-out lines would instead map s to 2/(1+exp(-s)) - 1, a value
// in (-1, 1).
166  //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
167  //return 2.0/(1.0+ep) - 1.0;
168  return dot + m_tasks_c[m_current_task];
169 }
170 
172 {
// NOTE(review): the getter signature lines are missing from this listing. Each body
// below returns one stored parameter unchanged.
// Returns m_max_iter.
173  return m_max_iter;
174 }
176 {
// Returns m_regularization.
177  return m_regularization;
178 }
180 {
// Returns m_termination.
181  return m_termination;
182 }
184 {
// Returns m_tolerance.
185  return m_tolerance;
186 }
188 {
// Returns m_z, the regularization coefficient.
189  return m_z;
190 }
192 {
// Returns m_q, the q of the L1/Lq penalty.
193  return m_q;
194 }
195 
197 {
// NOTE(review): the setter signature lines are missing from this listing.
// Stores max_iter; ASSERTs it is non-negative.
198  ASSERT(max_iter>=0)
199  m_max_iter = max_iter;
200 }
202 {
// Stores regularization; only 0 or 1 is accepted.
203  ASSERT(regularization==0 || regularization==1)
204  m_regularization = regularization;
205 }
207 {
// Stores termination; accepted range is 0..4 inclusive.
208  ASSERT(termination>=0 && termination<=4)
209  m_termination = termination;
210 }
212 {
// Stores tolerance; must be strictly positive.
213  ASSERT(tolerance>0.0)
214  m_tolerance = tolerance;
215 }
217 {
// Stores z without any range check, unlike the setters above.
// NOTE(review): consider whether a negative z should be rejected — confirm
// against the solver's contract.
218  m_z = z;
219 }
221 {
// Stores q without any range check.
222  m_q = q;
223 }
224 
225 }

SHOGUN Machine Learning Toolbox - Documentation