21 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression() :
22 CMultitaskLinearMachine()
24 initialize_parameters();
25 register_parameters();
28 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression(
29 float64_t z, CDotFeatures* train_features,
30 CRegressionLabels* train_labels, CTaskRelation* task_relation) :
31 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
34 initialize_parameters();
35 register_parameters();
38 CMultitaskLeastSquaresRegression::~CMultitaskLeastSquaresRegression()
42 void CMultitaskLeastSquaresRegression::register_parameters()
52 void CMultitaskLeastSquaresRegression::initialize_parameters()
57 set_regularization(0);
62 bool CMultitaskLeastSquaresRegression::train_locked_implementation(SGVector<index_t>* tasks)
68 float64_t CMultitaskLeastSquaresRegression::apply_one(int32_t i)
70 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
71 return dot + m_tasks_c[m_current_task];
74 int32_t CMultitaskLeastSquaresRegression::get_max_iter()
const
78 int32_t CMultitaskLeastSquaresRegression::get_regularization()
const
80 return m_regularization;
82 int32_t CMultitaskLeastSquaresRegression::get_termination()
const
86 float64_t CMultitaskLeastSquaresRegression::get_tolerance()
const
90 float64_t CMultitaskLeastSquaresRegression::get_z()
const
94 float64_t CMultitaskLeastSquaresRegression::get_q()
const
99 void CMultitaskLeastSquaresRegression::set_max_iter(int32_t max_iter)
102 m_max_iter = max_iter;
104 void CMultitaskLeastSquaresRegression::set_regularization(int32_t regularization)
106 ASSERT(regularization==0 || regularization==1)
107 m_regularization = regularization;
109 void CMultitaskLeastSquaresRegression::set_termination(int32_t termination)
111 ASSERT(termination>=0 && termination<=4)
112 m_termination = termination;
114 void CMultitaskLeastSquaresRegression::set_tolerance(
float64_t tolerance)
117 m_tolerance = tolerance;
119 void CMultitaskLeastSquaresRegression::set_z(
float64_t z)
123 void CMultitaskLeastSquaresRegression::set_q(
float64_t q)
128 bool CMultitaskLeastSquaresRegression::train_machine(CFeatures* data)
130 if (data && (CDotFeatures*)data)
131 set_features((CDotFeatures*)data);
136 SGVector<
float64_t> y = ((CRegressionLabels*)m_labels)->get_labels();
138 slep_options options = slep_options::default_options();
139 options.n_tasks = m_task_relation->get_num_tasks();
140 options.tasks_indices = m_task_relation->get_tasks_indices();
142 options.regularization = m_regularization;
143 options.termination = m_termination;
144 options.tolerance = m_tolerance;
145 options.max_iter = m_max_iter;
147 ETaskRelationType relation_type = m_task_relation->get_relation_type();
148 switch (relation_type)
153 options.mode = MULTITASK_GROUP;
154 options.loss = LEAST_SQUARES;
155 m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
156 m_tasks_c = SGVector<float64_t>(options.n_tasks);
162 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
163 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
164 options.ind_t = ind_t.vector;
165 options.n_nodes = ind_t.vlen/3;
166 options.mode = MULTITASK_TREE;
167 options.loss = LEAST_SQUARES;
168 m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
169 m_tasks_c = SGVector<float64_t>(options.n_tasks);
174 SG_ERROR(
"Not supported task relation type\n")
177 SG_FREE(options.tasks_indices);
184 #endif //USE_GPL_SHOGUN
Vector::Scalar dot(Vector a, Vector b)
#define SG_NOTIMPLEMENTED
all of classes and functions are contained in the shogun namespace