Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/transfer/multitask/MultitaskLogisticRegression.h>
00011 #include <shogun/lib/slep/slep_solver.h>
00012 #include <shogun/lib/slep/slep_options.h>
00013
00014 namespace shogun
00015 {
00016
00017 CMultitaskLogisticRegression::CMultitaskLogisticRegression() :
00018 CMultitaskLinearMachine()
00019 {
00020 initialize_parameters();
00021 register_parameters();
00022 }
00023
00024 CMultitaskLogisticRegression::CMultitaskLogisticRegression(
00025 float64_t z, CDotFeatures* train_features,
00026 CBinaryLabels* train_labels, CTaskRelation* task_relation) :
00027 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
00028 {
00029 initialize_parameters();
00030 register_parameters();
00031 set_z(z);
00032 }
00033
00034 CMultitaskLogisticRegression::~CMultitaskLogisticRegression()
00035 {
00036 }
00037
/** Registers the hyper-parameters of this machine with SG_ADD.
 *
 * z and q are flagged MS_AVAILABLE (exposed to model selection);
 * termination, regularization, tolerance and max_iter are flagged
 * MS_NOT_AVAILABLE. The registration order is kept as-is on purpose.
 */
void CMultitaskLogisticRegression::register_parameters()
{
SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
}
00047
00048 void CMultitaskLogisticRegression::initialize_parameters()
00049 {
00050 set_z(0.0);
00051 set_q(2.0);
00052 set_termination(0);
00053 set_regularization(0);
00054 set_tolerance(1e-3);
00055 set_max_iter(1000);
00056 }
00057
00058 bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
00059 {
00060 if (data && (CDotFeatures*)data)
00061 set_features((CDotFeatures*)data);
00062
00063 ASSERT(features);
00064 ASSERT(m_labels);
00065
00066 SGVector<float64_t> y(m_labels->get_num_labels());
00067 for (int32_t i=0; i<y.vlen; i++)
00068 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00069
00070 slep_options options = slep_options::default_options();
00071 options.n_tasks = m_task_relation->get_num_tasks();
00072 options.tasks_indices = m_task_relation->get_tasks_indices();
00073 options.q = m_q;
00074 options.regularization = m_regularization;
00075 options.termination = m_termination;
00076 options.tolerance = m_tolerance;
00077 options.max_iter = m_max_iter;
00078
00079 ETaskRelationType relation_type = m_task_relation->get_relation_type();
00080 switch (relation_type)
00081 {
00082 case TASK_GROUP:
00083 {
00084
00085 options.mode = MULTITASK_GROUP;
00086 options.loss = LOGISTIC;
00087 slep_result_t result = slep_solver(features, y.vector, m_z, options);
00088 m_tasks_w = result.w;
00089 m_tasks_c = result.c;
00090 }
00091 break;
00092 case TASK_TREE:
00093 {
00094 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00095 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00096 options.ind_t = ind_t.vector;
00097 options.n_nodes = ind_t.vlen / 3;
00098 options.mode = MULTITASK_TREE;
00099 options.loss = LOGISTIC;
00100 slep_result_t result = slep_solver(features, y.vector, m_z, options);
00101 m_tasks_w = result.w;
00102 m_tasks_c = result.c;
00103 }
00104 break;
00105 default:
00106 SG_ERROR("Not supported task relation type\n");
00107 }
00108 for (int32_t i=0; i<options.n_tasks; i++)
00109 options.tasks_indices[i].~SGVector<index_t>();
00110 SG_FREE(options.tasks_indices);
00111
00112 return true;
00113 }
00114
00115 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
00116 {
00117 ASSERT(features);
00118 ASSERT(m_labels);
00119
00120 SGVector<float64_t> y(m_labels->get_num_labels());
00121 for (int32_t i=0; i<y.vlen; i++)
00122 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00123
00124 slep_options options = slep_options::default_options();
00125 options.n_tasks = m_task_relation->get_num_tasks();
00126 options.tasks_indices = tasks;
00127 options.q = m_q;
00128 options.regularization = m_regularization;
00129 options.termination = m_termination;
00130 options.tolerance = m_tolerance;
00131 options.max_iter = m_max_iter;
00132
00133 ETaskRelationType relation_type = m_task_relation->get_relation_type();
00134 switch (relation_type)
00135 {
00136 case TASK_GROUP:
00137 {
00138
00139 options.mode = MULTITASK_GROUP;
00140 options.loss = LOGISTIC;
00141 slep_result_t result = slep_solver(features, y.vector, m_z, options);
00142 m_tasks_w = result.w;
00143 m_tasks_c = result.c;
00144 }
00145 break;
00146 case TASK_TREE:
00147 {
00148 CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00149 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00150 options.ind_t = ind_t.vector;
00151 options.n_nodes = ind_t.vlen / 3;
00152 options.mode = MULTITASK_TREE;
00153 options.loss = LOGISTIC;
00154 slep_result_t result = slep_solver(features, y.vector, m_z, options);
00155 m_tasks_w = result.w;
00156 m_tasks_c = result.c;
00157 }
00158 break;
00159 default:
00160 SG_ERROR("Not supported task relation type\n");
00161 }
00162 return true;
00163 }
00164
00165 float64_t CMultitaskLogisticRegression::apply_one(int32_t i)
00166 {
00167 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
00168
00169
00170 return dot + m_tasks_c[m_current_task];
00171 }
00172
00173 int32_t CMultitaskLogisticRegression::get_max_iter() const
00174 {
00175 return m_max_iter;
00176 }
00177 int32_t CMultitaskLogisticRegression::get_regularization() const
00178 {
00179 return m_regularization;
00180 }
00181 int32_t CMultitaskLogisticRegression::get_termination() const
00182 {
00183 return m_termination;
00184 }
00185 float64_t CMultitaskLogisticRegression::get_tolerance() const
00186 {
00187 return m_tolerance;
00188 }
00189 float64_t CMultitaskLogisticRegression::get_z() const
00190 {
00191 return m_z;
00192 }
00193 float64_t CMultitaskLogisticRegression::get_q() const
00194 {
00195 return m_q;
00196 }
00197
00198 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter)
00199 {
00200 ASSERT(max_iter>=0);
00201 m_max_iter = max_iter;
00202 }
00203 void CMultitaskLogisticRegression::set_regularization(int32_t regularization)
00204 {
00205 ASSERT(regularization==0 || regularization==1);
00206 m_regularization = regularization;
00207 }
00208 void CMultitaskLogisticRegression::set_termination(int32_t termination)
00209 {
00210 ASSERT(termination>=0 && termination<=4);
00211 m_termination = termination;
00212 }
00213 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance)
00214 {
00215 ASSERT(tolerance>0.0);
00216 m_tolerance = tolerance;
00217 }
00218 void CMultitaskLogisticRegression::set_z(float64_t z)
00219 {
00220 m_z = z;
00221 }
00222 void CMultitaskLogisticRegression::set_q(float64_t q)
00223 {
00224 m_q = q;
00225 }
00226
00227 }