Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/transfer/multitask/MultitaskLeastSquaresRegression.h>
00011 #include <shogun/transfer/multitask/TaskGroup.h>
00012 #include <shogun/transfer/multitask/TaskTree.h>
00013 #include <shogun/lib/slep/slep_solver.h>
00014 #include <shogun/lib/slep/slep_options.h>
00015
00016 namespace shogun
00017 {
00018
00019 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression() :
00020 CMultitaskLinearMachine()
00021 {
00022 initialize_parameters();
00023 register_parameters();
00024 }
00025
00026 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression(
00027 float64_t z, CDotFeatures* train_features,
00028 CRegressionLabels* train_labels, CTaskRelation* task_relation) :
00029 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
00030 {
00031 set_z(z);
00032 initialize_parameters();
00033 register_parameters();
00034 }
00035
00036 CMultitaskLeastSquaresRegression::~CMultitaskLeastSquaresRegression()
00037 {
00038 }
00039
00040 void CMultitaskLeastSquaresRegression::register_parameters()
00041 {
00042 SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
00043 SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
00044 SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
00045 SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
00046 SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
00047 SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
00048 }
00049
00050 void CMultitaskLeastSquaresRegression::initialize_parameters()
00051 {
00052 set_z(0.0);
00053 set_q(2.0);
00054 set_termination(0);
00055 set_regularization(0);
00056 set_tolerance(1e-3);
00057 set_max_iter(1000);
00058 }
00059
00060 bool CMultitaskLeastSquaresRegression::train_locked_implementation(SGVector<index_t>* tasks)
00061 {
00062 SG_NOTIMPLEMENTED;
00063 return false;
00064 }
00065
00066 float64_t CMultitaskLeastSquaresRegression::apply_one(int32_t i)
00067 {
00068 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
00069 return dot + m_tasks_c[m_current_task];
00070 }
00071
00072 int32_t CMultitaskLeastSquaresRegression::get_max_iter() const
00073 {
00074 return m_max_iter;
00075 }
00076 int32_t CMultitaskLeastSquaresRegression::get_regularization() const
00077 {
00078 return m_regularization;
00079 }
00080 int32_t CMultitaskLeastSquaresRegression::get_termination() const
00081 {
00082 return m_termination;
00083 }
00084 float64_t CMultitaskLeastSquaresRegression::get_tolerance() const
00085 {
00086 return m_tolerance;
00087 }
00088 float64_t CMultitaskLeastSquaresRegression::get_z() const
00089 {
00090 return m_z;
00091 }
00092 float64_t CMultitaskLeastSquaresRegression::get_q() const
00093 {
00094 return m_q;
00095 }
00096
00097 void CMultitaskLeastSquaresRegression::set_max_iter(int32_t max_iter)
00098 {
00099 ASSERT(max_iter>=0);
00100 m_max_iter = max_iter;
00101 }
00102 void CMultitaskLeastSquaresRegression::set_regularization(int32_t regularization)
00103 {
00104 ASSERT(regularization==0 || regularization==1);
00105 m_regularization = regularization;
00106 }
00107 void CMultitaskLeastSquaresRegression::set_termination(int32_t termination)
00108 {
00109 ASSERT(termination>=0 && termination<=4);
00110 m_termination = termination;
00111 }
00112 void CMultitaskLeastSquaresRegression::set_tolerance(float64_t tolerance)
00113 {
00114 ASSERT(tolerance>0.0);
00115 m_tolerance = tolerance;
00116 }
00117 void CMultitaskLeastSquaresRegression::set_z(float64_t z)
00118 {
00119 m_z = z;
00120 }
00121 void CMultitaskLeastSquaresRegression::set_q(float64_t q)
00122 {
00123 m_q = q;
00124 }
00125
00126 bool CMultitaskLeastSquaresRegression::train_machine(CFeatures* data)
00127 {
00128 if (data && (CDotFeatures*)data)
00129 set_features((CDotFeatures*)data);
00130
00131 ASSERT(features);
00132 ASSERT(m_labels);
00133
00134 SGVector<float64_t> y = ((CRegressionLabels*)m_labels)->get_labels();
00135
00136 slep_options options = slep_options::default_options();
00137 options.n_tasks = m_task_relation->get_num_tasks();
00138 options.tasks_indices = m_task_relation->get_tasks_indices();
00139 options.q = m_q;
00140 options.regularization = m_regularization;
00141 options.termination = m_termination;
00142 options.tolerance = m_tolerance;
00143 options.max_iter = m_max_iter;
00144
00145 ETaskRelationType relation_type = m_task_relation->get_relation_type();
00146 switch (relation_type)
00147 {
00148 case TASK_GROUP:
00149 {
00150
00151 options.mode = MULTITASK_GROUP;
00152 options.loss = LEAST_SQUARES;
00153 m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
00154 m_tasks_c = SGVector<float64_t>(options.n_tasks);
00155 m_tasks_c.zero();
00156 }
00157 break;
00158 case TASK_TREE:
00159 {
00160 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00161 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00162 options.ind_t = ind_t.vector;
00163 options.n_nodes = ind_t.vlen/3;
00164 options.mode = MULTITASK_TREE;
00165 options.loss = LEAST_SQUARES;
00166 m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
00167 m_tasks_c = SGVector<float64_t>(options.n_tasks);
00168 m_tasks_c.zero();
00169 }
00170 break;
00171 default:
00172 SG_ERROR("Not supported task relation type\n");
00173 }
00174
00175 for (int32_t i=0; i<options.n_tasks; i++)
00176 options.tasks_indices[i].~SGVector<index_t>();
00177 SG_FREE(options.tasks_indices);
00178
00179 return true;
00180 }
00181
00182 }