20 CMultitaskLogisticRegression::CMultitaskLogisticRegression() :
21 CMultitaskLinearMachine()
23 initialize_parameters();
24 register_parameters();
27 CMultitaskLogisticRegression::CMultitaskLogisticRegression(
28 float64_t z, CDotFeatures* train_features,
29 CBinaryLabels* train_labels, CTaskRelation* task_relation) :
30 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
32 initialize_parameters();
33 register_parameters();
// Destructor.
// NOTE(review): no body is visible in this listing; presumably empty - confirm
// against the full file.
37 CMultitaskLogisticRegression::~CMultitaskLogisticRegression()
// Called as the last setup step by both constructors, after
// initialize_parameters().
// NOTE(review): the body is not visible in this listing; presumably it
// registers the solver settings with the parameter framework - confirm.
41 void CMultitaskLogisticRegression::register_parameters()
// Called first by both constructors, before register_parameters().
// NOTE(review): only the set_regularization(0) call is visible in this
// listing; any other defaults assigned here are not shown - confirm against
// the full file.
51 void CMultitaskLogisticRegression::initialize_parameters()
// Regularization mode 0; set_regularization() accepts only 0 or 1.
56 set_regularization(0);
61 bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
63 if (data && (CDotFeatures*)data)
64 set_features((CDotFeatures*)data);
69 SGVector<
float64_t> y(m_labels->get_num_labels());
70 for (int32_t i=0; i<y.vlen; i++)
71 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
73 slep_options options = slep_options::default_options();
74 options.n_tasks = m_task_relation->get_num_tasks();
75 options.tasks_indices = m_task_relation->get_tasks_indices();
77 options.regularization = m_regularization;
78 options.termination = m_termination;
79 options.tolerance = m_tolerance;
80 options.max_iter = m_max_iter;
82 ETaskRelationType relation_type = m_task_relation->get_relation_type();
83 switch (relation_type)
88 options.mode = MULTITASK_GROUP;
89 options.loss = LOGISTIC;
90 slep_result_t result = slep_solver(features, y.vector, m_z, options);
97 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
98 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
99 options.ind_t = ind_t.vector;
100 options.n_nodes = ind_t.vlen / 3;
101 options.mode = MULTITASK_TREE;
102 options.loss = LOGISTIC;
103 slep_result_t result = slep_solver(features, y.vector, m_z, options);
104 m_tasks_w = result.w;
105 m_tasks_c = result.c;
109 SG_ERROR(
"Not supported task relation type\n")
111 SG_FREE(options.tasks_indices);
116 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
121 SGVector<
float64_t> y(m_labels->get_num_labels());
122 for (int32_t i=0; i<y.vlen; i++)
123 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
125 slep_options options = slep_options::default_options();
126 options.n_tasks = m_task_relation->get_num_tasks();
127 options.tasks_indices = tasks;
129 options.regularization = m_regularization;
130 options.termination = m_termination;
131 options.tolerance = m_tolerance;
132 options.max_iter = m_max_iter;
134 ETaskRelationType relation_type = m_task_relation->get_relation_type();
135 switch (relation_type)
140 options.mode = MULTITASK_GROUP;
141 options.loss = LOGISTIC;
142 slep_result_t result = slep_solver(features, y.vector, m_z, options);
143 m_tasks_w = result.w;
144 m_tasks_c = result.c;
149 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
150 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
151 options.ind_t = ind_t.vector;
152 options.n_nodes = ind_t.vlen / 3;
153 options.mode = MULTITASK_TREE;
154 options.loss = LOGISTIC;
155 slep_result_t result = slep_solver(features, y.vector, m_z, options);
156 m_tasks_w = result.w;
157 m_tasks_c = result.c;
161 SG_ERROR(
"Not supported task relation type\n")
166 float64_t CMultitaskLogisticRegression::apply_one(int32_t i)
168 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
171 return dot + m_tasks_c[m_current_task];
// Const getter for the iteration cap.
// NOTE(review): the body is not visible in this listing; presumably returns
// m_max_iter (the member written by set_max_iter()) - confirm.
174 int32_t CMultitaskLogisticRegression::get_max_iter()
const
178 int32_t CMultitaskLogisticRegression::get_regularization()
const
180 return m_regularization;
182 int32_t CMultitaskLogisticRegression::get_termination()
const
184 return m_termination;
// Const getter for the solver tolerance.
// NOTE(review): the body is not visible in this listing; presumably returns
// m_tolerance (the member written by set_tolerance()) - confirm.
186 float64_t CMultitaskLogisticRegression::get_tolerance()
const
// Const getter for z.
// NOTE(review): the body is not visible in this listing; presumably returns
// m_z (the value passed to slep_solver() during training) - confirm.
190 float64_t CMultitaskLogisticRegression::get_z()
const
// Const getter for q.
// NOTE(review): the body is not visible in this listing and the backing
// member is not referenced elsewhere in the visible code - confirm against
// the full file.
194 float64_t CMultitaskLogisticRegression::get_q()
const
199 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter)
202 m_max_iter = max_iter;
204 void CMultitaskLogisticRegression::set_regularization(int32_t regularization)
206 ASSERT(regularization==0 || regularization==1)
207 m_regularization = regularization;
209 void CMultitaskLogisticRegression::set_termination(int32_t termination)
211 ASSERT(termination>=0 && termination<=4)
212 m_termination = termination;
214 void CMultitaskLogisticRegression::set_tolerance(
float64_t tolerance)
217 m_tolerance = tolerance;
// Setter for z.
// NOTE(review): the body is not visible in this listing; presumably assigns
// m_z (the value passed to slep_solver() during training) - confirm.
219 void CMultitaskLogisticRegression::set_z(
float64_t z)
// Setter for q.
// NOTE(review): the body is not visible in this listing and the backing
// member is not referenced elsewhere in the visible code - confirm against
// the full file.
223 void CMultitaskLogisticRegression::set_q(
float64_t q)
230 #endif //USE_GPL_SHOGUN
Vector::Scalar dot(Vector a, Vector b)
All classes and functions are contained in the shogun namespace.