SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MultitaskLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
13 
14 namespace shogun
15 {
16 
// Default constructor body (its signature line is not visible in this
// listing): applies the default parameter values via initialize_parameters()
// and then registers the members with the parameter framework via
// register_parameters().
19 {
20  initialize_parameters();
21  register_parameters();
22 }
23 
// Remaining parameters of the constructor that takes the regularization
// coefficient `z` plus training features, binary labels and the task relation
// (the opening line of the signature is not visible in this listing).
// Features, labels (upcast to CLabels*) and relation are forwarded to
// CMultitaskLinearMachine; defaults are then applied and registered, and
// finally `z` overrides the default coefficient set by initialize_parameters().
25  float64_t z, CDotFeatures* train_features,
26  CBinaryLabels* train_labels, CTaskRelation* task_relation) :
27  CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
28 {
29  initialize_parameters();
30  register_parameters();
31  set_z(z);
32 }
33 
// Destructor body (signature line not visible in this listing): empty, this
// class releases nothing itself.
35 {
36 }
37 
38 void CMultitaskLogisticRegression::register_parameters()
39 {
40  SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
41  SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
42  SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
43  SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
44  SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
45  SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
46 }
47 
48 void CMultitaskLogisticRegression::initialize_parameters()
49 {
50  set_z(0.0);
51  set_q(2.0);
52  set_termination(0);
54  set_tolerance(1e-3);
55  set_max_iter(1000);
56 }
57 
// Training entry-point body. Its signature, the input ASSERTs and the
// declaration of the label vector `y` are on lines not visible in this listing.
// Flow: optionally adopt `data` as the training features, copy the binary
// labels into `y`, fill a slep_options struct from this object's parameters
// and the task relation, run slep_solver for the group or tree relation type,
// store result.w / result.c in m_tasks_w / m_tasks_c, then destroy and free
// the task index array obtained from the task relation. Always returns true.
59 {
// NOTE(review): a C-style cast of a non-null pointer is never null, so the
// second operand of this condition adds no check and does not verify that
// `data` really is a CDotFeatures; a real type check should be added.
60  if (data && (CDotFeatures*)data)
61  set_features((CDotFeatures*)data);
62
65
// Copy labels into the solver's label buffer.
// NOTE(review): unchecked downcast, assumes m_labels is a CBinaryLabels.
67  for (int32_t i=0; i<y.vlen; i++)
68  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
69
// Start from the solver defaults and override with this machine's settings.
// tasks_indices is released explicitly after the switch below.
70  slep_options options = slep_options::default_options();
71  options.n_tasks = m_task_relation->get_num_tasks();
72  options.tasks_indices = m_task_relation->get_tasks_indices();
73  options.q = m_q;
74  options.regularization = m_regularization;
75  options.termination = m_termination;
76  options.tolerance = m_tolerance;
77  options.max_iter = m_max_iter;
78
79  ETaskRelationType relation_type = m_task_relation->get_relation_type();
80  switch (relation_type)
81  {
82  case TASK_GROUP:
83  {
84  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
85  options.mode = MULTITASK_GROUP;
86  options.loss = LOGISTIC;
87  slep_result_t result = slep_solver(features, y.vector, m_z, options);
88  m_tasks_w = result.w;
89  m_tasks_c = result.c;
90  }
91  break;
92  case TASK_TREE:
93  {
// The relation type was checked above, so the downcast targets a task tree.
// options.ind_t borrows ind_t's buffer; ind_t stays alive for this whole case
// scope, which covers the solver call.
94  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
95  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
96  options.ind_t = ind_t.vector;
// assumes ind_t packs three entries per tree node (SLEP index layout) — TODO confirm
97  options.n_nodes = ind_t.vlen / 3;
98  options.mode = MULTITASK_TREE;
99  options.loss = LOGISTIC;
100  slep_result_t result = slep_solver(features, y.vector, m_z, options);
101  m_tasks_w = result.w;
102  m_tasks_c = result.c;
103  }
104  break;
105  default:
// NOTE(review): if SG_ERROR does not return (e.g. it throws), the cleanup of
// options.tasks_indices below is skipped and that array leaks. Consider
// checking the relation type before calling get_tasks_indices().
106  SG_ERROR("Not supported task relation type\n");
107  }
// Release the task index array: destroy each SGVector in place, then free the
// raw storage. Presumably get_tasks_indices() hands ownership to the caller
// as an SG_MALLOC'd array — confirm against CTaskRelation.
108  for (int32_t i=0; i<options.n_tasks; i++)
109  options.tasks_indices[i].~SGVector<index_t>();
110  SG_FREE(options.tasks_indices);
111
112  return true;
113 }
114 
// Body of the training routine that receives precomputed task index vectors
// `tasks`. Its signature and the declaration of the label vector `y` are on
// lines not visible in this listing.
// Runs the same slep_solver dispatch as the regular training path, but takes
// options.tasks_indices from `tasks` and never frees it. Ownership presumably
// stays with the caller — confirm. Always returns true.
116 {
// NOTE(review): m_task_relation is dereferenced below without an ASSERT.
117  ASSERT(features);
118  ASSERT(m_labels);
119
// Copy labels into the solver's label buffer (unchecked CBinaryLabels downcast).
121  for (int32_t i=0; i<y.vlen; i++)
122  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
123
// NOTE(review): n_tasks comes from the relation while the indices come from
// `tasks`. This assumes `tasks` holds exactly get_num_tasks() entries — confirm.
124  slep_options options = slep_options::default_options();
125  options.n_tasks = m_task_relation->get_num_tasks();
126  options.tasks_indices = tasks;
127  options.q = m_q;
128  options.regularization = m_regularization;
129  options.termination = m_termination;
130  options.tolerance = m_tolerance;
131  options.max_iter = m_max_iter;
132
// NOTE(review): this switch duplicates the regular training path verbatim; a
// shared helper taking the filled options would keep the two in sync.
133  ETaskRelationType relation_type = m_task_relation->get_relation_type();
134  switch (relation_type)
135  {
136  case TASK_GROUP:
137  {
138  //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
139  options.mode = MULTITASK_GROUP;
140  options.loss = LOGISTIC;
141  slep_result_t result = slep_solver(features, y.vector, m_z, options);
142  m_tasks_w = result.w;
143  m_tasks_c = result.c;
144  }
145  break;
146  case TASK_TREE:
147  {
// ind_t's buffer is borrowed by options.ind_t for the solver call in this scope.
148  CTaskTree* task_tree = (CTaskTree*)m_task_relation;
149  SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
150  options.ind_t = ind_t.vector;
// assumes three entries per tree node in ind_t — TODO confirm
151  options.n_nodes = ind_t.vlen / 3;
152  options.mode = MULTITASK_TREE;
153  options.loss = LOGISTIC;
154  slep_result_t result = slep_solver(features, y.vector, m_z, options);
155  m_tasks_w = result.w;
156  m_tasks_c = result.c;
157  }
158  break;
159  default:
160  SG_ERROR("Not supported task relation type\n");
161  }
162  return true;
163 }
164 
// Body of the per-example scoring routine. Its signature and the computation
// of `dot` are on lines not visible in this listing.
// Returns the raw linear score `dot` plus the bias of the current task.
// The commented-out lines would instead map the score through
// 2/(1+exp(-x)) - 1 (i.e. tanh(x/2)), which lies in (-1, 1).
166 {
168  //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
169  //return 2.0/(1.0+ep) - 1.0;
170  return dot + m_tasks_c[m_current_task];
171 }
172 
// Bodies of six accessors (their signature lines are not visible in this
// listing). Each returns one stored parameter unchanged:
// m_max_iter, m_regularization, m_termination, m_tolerance, m_z, m_q.
174 {
175  return m_max_iter;
176 }
178 {
179  return m_regularization;
180 }
182 {
183  return m_termination;
184 }
186 {
187  return m_tolerance;
188 }
190 {
191  return m_z;
192 }
194 {
195  return m_q;
196 }
197 
// Bodies of six mutators (their signature lines are not visible in this
// listing). Each ASSERTs its argument's allowed range where one is enforced,
// then stores it in the matching member.
// max_iter: must be >= 0.
199 {
200  ASSERT(max_iter>=0);
201  m_max_iter = max_iter;
202 }
// regularization: only 0 or 1 are accepted.
204 {
205  ASSERT(regularization==0 || regularization==1);
206  m_regularization = regularization;
207 }
// termination: integer code in [0, 4].
209 {
210  ASSERT(termination>=0 && termination<=4);
211  m_termination = termination;
212 }
// tolerance: must be strictly positive.
214 {
215  ASSERT(tolerance>0.0);
216  m_tolerance = tolerance;
217 }
// z: stored without validation.
// NOTE(review): unlike the setters above, no range check is applied; confirm
// whether a negative z is meaningful for the solver.
219 {
220  m_z = z;
221 }
// q: stored without validation (NOTE(review): no range check either).
223 {
224  m_q = q;
225 }
226 
227 }

SHOGUN Machine Learning Toolbox - Documentation