MultitaskLogisticRegression.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskLogisticRegression.h>
00011 #include <shogun/lib/slep/slep_solver.h>
00012 #include <shogun/lib/slep/slep_options.h>
00013 
00014 namespace shogun
00015 {
00016 
00017 CMultitaskLogisticRegression::CMultitaskLogisticRegression() :
00018     CMultitaskLinearMachine()
00019 {
00020     initialize_parameters();
00021     register_parameters();
00022 }
00023 
00024 CMultitaskLogisticRegression::CMultitaskLogisticRegression(
00025      float64_t z, CDotFeatures* train_features, 
00026      CBinaryLabels* train_labels, CTaskRelation* task_relation) :
00027     CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
00028 {
00029     initialize_parameters();
00030     register_parameters();
00031     set_z(z);
00032 }
00033 
00034 CMultitaskLogisticRegression::~CMultitaskLogisticRegression()
00035 {
00036 }
00037 
00038 void CMultitaskLogisticRegression::register_parameters()
00039 {
00040     SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
00041     SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
00042     SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
00043     SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
00044     SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
00045     SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
00046 }
00047 
00048 void CMultitaskLogisticRegression::initialize_parameters()
00049 {
00050     set_z(0.0);
00051     set_q(2.0);
00052     set_termination(0);
00053     set_regularization(0);
00054     set_tolerance(1e-3);
00055     set_max_iter(1000);
00056 }
00057 
00058 bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
00059 {
00060     if (data && (CDotFeatures*)data)
00061         set_features((CDotFeatures*)data);
00062 
00063     ASSERT(features);
00064     ASSERT(m_labels);
00065     
00066     SGVector<float64_t> y(m_labels->get_num_labels());
00067     for (int32_t i=0; i<y.vlen; i++)
00068         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00069     
00070     slep_options options = slep_options::default_options();
00071     options.n_tasks = m_task_relation->get_num_tasks();
00072     options.tasks_indices = m_task_relation->get_tasks_indices();
00073     options.q = m_q;
00074     options.regularization = m_regularization;
00075     options.termination = m_termination;
00076     options.tolerance = m_tolerance;
00077     options.max_iter = m_max_iter;
00078 
00079     ETaskRelationType relation_type = m_task_relation->get_relation_type();
00080     switch (relation_type)
00081     {
00082         case TASK_GROUP:
00083         {
00084             //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
00085             options.mode = MULTITASK_GROUP;
00086             options.loss = LOGISTIC;
00087             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00088             m_tasks_w = result.w;
00089             m_tasks_c = result.c;
00090         }
00091         break;
00092         case TASK_TREE: 
00093         {
00094             CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00095             SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00096             options.ind_t = ind_t.vector;
00097             options.n_nodes = ind_t.vlen / 3;
00098             options.mode = MULTITASK_TREE;
00099             options.loss = LOGISTIC;
00100             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00101             m_tasks_w = result.w;
00102             m_tasks_c = result.c;
00103         }
00104         break;
00105         default: 
00106             SG_ERROR("Not supported task relation type\n");
00107     }
00108     for (int32_t i=0; i<options.n_tasks; i++)
00109         options.tasks_indices[i].~SGVector<index_t>();
00110     SG_FREE(options.tasks_indices);
00111 
00112     return true;
00113 }
00114 
00115 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
00116 {
00117     ASSERT(features);
00118     ASSERT(m_labels);
00119     
00120     SGVector<float64_t> y(m_labels->get_num_labels());
00121     for (int32_t i=0; i<y.vlen; i++)
00122         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00123     
00124     slep_options options = slep_options::default_options();
00125     options.n_tasks = m_task_relation->get_num_tasks();
00126     options.tasks_indices = tasks;
00127     options.q = m_q;
00128     options.regularization = m_regularization;
00129     options.termination = m_termination;
00130     options.tolerance = m_tolerance;
00131     options.max_iter = m_max_iter;
00132 
00133     ETaskRelationType relation_type = m_task_relation->get_relation_type();
00134     switch (relation_type)
00135     {
00136         case TASK_GROUP:
00137         {
00138             //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
00139             options.mode = MULTITASK_GROUP;
00140             options.loss = LOGISTIC;
00141             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00142             m_tasks_w = result.w;
00143             m_tasks_c = result.c;
00144         }
00145         break;
00146         case TASK_TREE: 
00147         {
00148             CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00149             SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00150             options.ind_t = ind_t.vector;
00151             options.n_nodes = ind_t.vlen / 3;
00152             options.mode = MULTITASK_TREE;
00153             options.loss = LOGISTIC;
00154             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00155             m_tasks_w = result.w;
00156             m_tasks_c = result.c;
00157         }
00158         break;
00159         default: 
00160             SG_ERROR("Not supported task relation type\n");
00161     }
00162     return true;
00163 }
00164 
00165 float64_t CMultitaskLogisticRegression::apply_one(int32_t i)
00166 {
00167     float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
00168     //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
00169     //return 2.0/(1.0+ep) - 1.0;
00170     return dot + m_tasks_c[m_current_task];
00171 }
00172 
00173 int32_t CMultitaskLogisticRegression::get_max_iter() const
00174 {
00175     return m_max_iter;
00176 }
00177 int32_t CMultitaskLogisticRegression::get_regularization() const
00178 {
00179     return m_regularization;
00180 }
00181 int32_t CMultitaskLogisticRegression::get_termination() const
00182 {
00183     return m_termination;
00184 }
00185 float64_t CMultitaskLogisticRegression::get_tolerance() const
00186 {
00187     return m_tolerance;
00188 }
00189 float64_t CMultitaskLogisticRegression::get_z() const
00190 {
00191     return m_z;
00192 }
00193 float64_t CMultitaskLogisticRegression::get_q() const
00194 {
00195     return m_q;
00196 }
00197 
00198 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter)
00199 {
00200     ASSERT(max_iter>=0);
00201     m_max_iter = max_iter;
00202 }
00203 void CMultitaskLogisticRegression::set_regularization(int32_t regularization)
00204 {
00205     ASSERT(regularization==0 || regularization==1);
00206     m_regularization = regularization;
00207 }
00208 void CMultitaskLogisticRegression::set_termination(int32_t termination)
00209 {
00210     ASSERT(termination>=0 && termination<=4);
00211     m_termination = termination;
00212 }
00213 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance)
00214 {
00215     ASSERT(tolerance>0.0);
00216     m_tolerance = tolerance;
00217 }
00218 void CMultitaskLogisticRegression::set_z(float64_t z)
00219 {
00220     m_z = z;
00221 }
00222 void CMultitaskLogisticRegression::set_q(float64_t q)
00223 {
00224     m_q = q;
00225 }
00226 
00227 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation