// Default constructor: chains to CLinearMachine(), starts with m_current_task = 0
// and then calls register_parameters().
// NOTE(review): the initializer list below ends with a trailing comma, so at least
// one more initializer (and the braces) is not visible in this excerpt -- confirm
// against the full file.
24 CMultitaskLinearMachine::CMultitaskLinearMachine() :
25 CLinearMachine(), m_current_task(0),
28 register_parameters();
// Constructor taking training data and a task relation.
//
// @param train_features features forwarded to set_features()
// @param train_labels   labels forwarded to set_labels()
// @param task_relation  task relation forwarded to set_task_relation()
//
// m_current_task starts at 0 and m_task_relation is initialised to NULL before
// set_task_relation() is called, so the setter never sees an uninitialised pointer.
31 CMultitaskLinearMachine::CMultitaskLinearMachine(
32 CDotFeatures* train_features,
33 CLabels* train_labels, CTaskRelation* task_relation) :
34 CLinearMachine(), m_current_task(0), m_task_relation(NULL)
36 set_features(train_features);
37 set_labels(train_labels);
38 set_task_relation(task_relation);
// Parameter registration happens last, after all members have been set.
39 register_parameters();
// Destructor.
// NOTE(review): the body is not visible in this excerpt; confirm in the full file
// how m_task_relation is released.
42 CMultitaskLinearMachine::~CMultitaskLinearMachine()
// Called by both constructors as their final step.
// NOTE(review): the body is not visible in this excerpt, so which members are
// registered cannot be stated from here.
47 void CMultitaskLinearMachine::register_parameters()
// Returns m_current_task, the task index that get_w()/set_w() use as the column of
// m_tasks_w and that get_bias()/set_bias() use as the index into m_tasks_c.
//
// @return current task index
52 int32_t CMultitaskLinearMachine::get_current_task()
const
54 return m_current_task;
// Selects the task whose weights and bias subsequent accessor calls operate on.
//
// @param task task index; must be smaller than the number of columns of m_tasks_w
//
// NOTE(review): only the upper-bound assertion is visible in this excerpt; confirm
// in the full file that negative indices are rejected as well.
57 void CMultitaskLinearMachine::set_current_task(int32_t task)
60 ASSERT(task<m_tasks_w.num_cols)
61 m_current_task = task;
// Returns the stored task relation pointer (may be NULL, which is the value both
// constructors initialise it with).
//
// @return m_task_relation
//
// NOTE(review): any reference counting performed before the return is not visible
// in this excerpt -- confirm the ownership contract against the full file.
64 CTaskRelation* CMultitaskLinearMachine::get_task_relation()
const
67 return m_task_relation;
// Replaces the stored task relation with task_relation.
//
// @param task_relation new task relation to store in m_task_relation
//
// NOTE(review): statements between the signature and the assignment are not visible
// in this excerpt; confirm how the previous relation is released and whether the new
// one is referenced.
70 void CMultitaskLinearMachine::set_task_relation(CTaskRelation* task_relation)
74 m_task_relation = task_relation;
// Training entry point taking optional feature data.
// NOTE(review): the body is not visible in this excerpt; presumably left to
// subclasses -- confirm in the full file.
77 bool CMultitaskLinearMachine::train_machine(CFeatures* data)
// Rebuilds the per-task index sets (m_tasks_indices) that the locked train/apply
// routines use for membership tests.
//
// @param labels    not referenced by any line visible in this excerpt
// @param features_ installed via set_features() after an unchecked cast to CDotFeatures*
83 void CMultitaskLinearMachine::post_lock(CLabels* labels, CFeatures* features_)
85 set_features((CDotFeatures*)features_);
// m_task_relation is cast to CTaskGroup* without a type check.
86 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
87 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
// Drop any sets left over from a previous call, then build one set per task.
89 m_tasks_indices.clear();
90 for (int32_t i=0; i<n_tasks; i++)
92 std::set<index_t> indices_set;
93 SGVector<index_t> task_indices = tasks_indices[i];
94 for (int32_t j=0; j<task_indices.vlen; j++)
95 indices_set.insert(task_indices[j]);
97 m_tasks_indices.push_back(indices_set);
// The array returned by get_tasks_indices() is freed here.
// NOTE(review): presumably the caller owns that array -- confirm against CTaskGroup.
100 SG_FREE(tasks_indices);
// Trains on a subset of the locked data: splits `indices` by task and forwards the
// per-task index vectors to train_locked_implementation().
//
// @param indices indices of the vectors to train on
// @return result of train_locked_implementation() (stored in `res`)
//
// NOTE(review): the lines after the call that produces `res` (releasing `tasks`,
// returning) are not visible in this excerpt -- confirm in the full file.
103 bool CMultitaskLinearMachine::train_locked(SGVector<index_t> indices)
105 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
// m_tasks_indices must hold exactly one set per task.
106 ASSERT((
int)m_tasks_indices.size()==n_tasks)
// One (initially empty) bucket of indices per task.
107 vector< vector<index_t> > cutted_task_indices;
108 for (int32_t i=0; i<n_tasks; i++)
109 cutted_task_indices.push_back(vector<index_t>());
// Route every requested index into the bucket of a task whose set contains it.
110 for (int32_t i=0; i<indices.vlen; i++)
112 for (int32_t j=0; j<n_tasks; j++)
114 if (m_tasks_indices[j].count(indices[i]))
116 cutted_task_indices[j].push_back(indices[i]);
// Copy the buckets into an SG_MALLOC'ed array of SGVectors, one per task.
121 SGVector<index_t>* tasks = SG_MALLOC(SGVector<index_t>, n_tasks);
122 for (int32_t i=0; i<n_tasks; i++)
124 tasks[i]=SGVector<index_t>(cutted_task_indices[i].size());
125 for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++)
126 tasks[i][j] = cutted_task_indices[i][j];
129 bool res = train_locked_implementation(tasks);
// Receives the per-task index vectors assembled by train_locked().
//
// @param tasks array of index vectors, one per task
//
// NOTE(review): the body is not visible in this excerpt; presumably overridden by
// concrete machines -- confirm in the full file.
134 bool CMultitaskLinearMachine::train_locked_implementation(SGVector<index_t>* tasks)
// Applies the machine to a subset of the locked data and wraps the outputs in
// binary labels.
//
// @param indices indices of the vectors to classify
// @return newly allocated CBinaryLabels built from the per-index outputs
//
// NOTE(review): no initialisation of `result` and no call switching the current
// task before apply_one() is visible in this excerpt -- confirm in the full file.
140 CBinaryLabels* CMultitaskLinearMachine::apply_locked_binary(SGVector<index_t> indices)
142 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
143 SGVector<float64_t> result(indices.vlen);
// For each requested index, look for a task whose index set contains it and store
// the output for that vector.
145 for (int32_t i=0; i<indices.vlen; i++)
147 for (int32_t j=0; j<n_tasks; j++)
149 if (m_tasks_indices[j].count(indices[i]))
152 result[i] = apply_one(indices[i]);
157 return new CBinaryLabels(result);
// Computes the output for a single vector; used by apply_locked_binary() and
// apply_get_outputs().
//
// @param i index of the vector
//
// NOTE(review): the body is not visible in this excerpt.
160 float64_t CMultitaskLinearMachine::apply_one(int32_t i)
// Computes the outputs for every vector of the feature object.
//
// @param data features to apply to; must have the FP_DOT property, otherwise
//             SG_ERROR is raised
// @return vector holding apply_one(i) for each of the feature vectors, or an empty
//         vector on the early-return path below
//
// NOTE(review): the guarding conditions around the checks/early return and the
// allocation of `out` are not visible in this excerpt -- confirm in the full file.
166 SGVector<float64_t> CMultitaskLinearMachine::apply_get_outputs(CFeatures* data)
170 if (!data->has_property(
FP_DOT))
171 SG_ERROR(
"Specified features are not of type CDotFeatures\n")
173 set_features((CDotFeatures*) data);
177 return SGVector<float64_t>();
179 int32_t num=features->get_num_vectors();
182 for (int32_t i=0; i<num; i++)
183 out[i] = apply_one(i);
// The raw buffer is handed to the returned SGVector together with its length.
185 return SGVector<float64_t>(out,num);
// Copies the weight vector of the current task out of the task weight matrix:
// column m_current_task of m_tasks_w, one entry per matrix row.
//
// @return copy of the current task's weights (the return statement itself is not
//         visible in this excerpt)
188 SGVector<float64_t> CMultitaskLinearMachine::get_w()
const
190 SGVector<float64_t> w_(m_tasks_w.num_rows);
191 for (int32_t i=0; i<w_.vlen; i++)
192 w_[i] = m_tasks_w(i,m_current_task);
// Overwrites the weight vector of the current task: writes src_w[i] into column
// m_current_task of m_tasks_w for every matrix row i.
//
// @param src_w new weights; read at indices 0 .. m_tasks_w.num_rows-1
//
// NOTE(review): no check of src_w.vlen against m_tasks_w.num_rows is visible here.
196 void CMultitaskLinearMachine::set_w(
const SGVector<float64_t> src_w)
198 for (int32_t i=0; i<m_tasks_w.num_rows; i++)
199 m_tasks_w(i,m_current_task) = src_w[i];
// Sets the bias of the current task, i.e. entry m_current_task of m_tasks_c.
//
// @param b new bias value
202 void CMultitaskLinearMachine::set_bias(float64_t b)
204 m_tasks_c[m_current_task] = b;
// Returns the bias of the current task, i.e. entry m_current_task of m_tasks_c.
//
// @return bias value of the current task
207 float64_t CMultitaskLinearMachine::get_bias()
209 return m_tasks_c[m_current_task];
// Translates the task index vectors of the task relation into positions within the
// feature object's subset stack, dropping indices that the subset does not map to.
//
// @return SG_MALLOC'ed array with one index vector per task
//         NOTE(review): presumably the caller must SG_FREE it -- confirm at call sites.
212 SGVector<index_t>* CMultitaskLinearMachine::get_subset_tasks_indices()
214 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
215 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
// Build the inverse mapping: value of subset_idx_conversion(i) -> i.
// NOTE(review): no release of `sstack` is visible in this excerpt -- confirm.
217 CSubsetStack* sstack = features->get_subset_stack();
218 map<index_t,index_t> subset_inv_map = map<index_t,index_t>();
219 for (int32_t i=0; i<sstack->get_size(); i++)
220 subset_inv_map[sstack->subset_idx_conversion(i)] = i;
225 SGVector<index_t>* subset_tasks_indices = SG_MALLOC(SGVector<index_t>, n_tasks);
226 for (int32_t i=0; i<n_tasks; i++)
228 SGVector<index_t> task = tasks_indices[i];
// Keep only the indices that have an entry in the inverse map, translated to the
// mapped position.
230 vector<index_t> cutted = vector<index_t>();
231 for (int32_t j=0; j<task.vlen; j++)
233 if (subset_inv_map.count(task[j]))
234 cutted.push_back(subset_inv_map[task[j]]);
// Copy the filtered indices into an SGVector for the result array.
236 SGVector<index_t> cutted_task(cutted.size());
237 for (int32_t j=0; j<cutted_task.vlen; j++)
238 cutted_task[j] = cutted[j];
240 subset_tasks_indices[i] = cutted_task;
// The array obtained from get_tasks_indices() is freed before returning.
242 SG_FREE(tasks_indices);
244 return subset_tasks_indices;
249 #endif //USE_GPL_SHOGUN
#define SG_NOTIMPLEMENTED
All classes and functions are contained in the shogun namespace.