MultitaskLeastSquaresRegression.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskLeastSquaresRegression.h>
00011 #include <shogun/transfer/multitask/TaskGroup.h>
00012 #include <shogun/transfer/multitask/TaskTree.h>
00013 #include <shogun/lib/slep/slep_solver.h>
00014 #include <shogun/lib/slep/slep_options.h>
00015 
00016 namespace shogun
00017 {
00018 
00019 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression() :
00020     CMultitaskLinearMachine()
00021 {
00022     initialize_parameters();
00023     register_parameters();
00024 }
00025 
00026 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression(
00027      float64_t z, CDotFeatures* train_features, 
00028      CRegressionLabels* train_labels, CTaskRelation* task_relation) :
00029     CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
00030 {
00031     set_z(z);
00032     initialize_parameters();
00033     register_parameters();
00034 }
00035 
00036 CMultitaskLeastSquaresRegression::~CMultitaskLeastSquaresRegression()
00037 {
00038 }
00039 
00040 void CMultitaskLeastSquaresRegression::register_parameters()
00041 {
00042     SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
00043     SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
00044     SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
00045     SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
00046     SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
00047     SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
00048 }
00049 
00050 void CMultitaskLeastSquaresRegression::initialize_parameters()
00051 {
00052     set_z(0.0);
00053     set_q(2.0);
00054     set_termination(0);
00055     set_regularization(0);
00056     set_tolerance(1e-3);
00057     set_max_iter(1000);
00058 }
00059 
00060 bool CMultitaskLeastSquaresRegression::train_locked_implementation(SGVector<index_t>* tasks)
00061 {
00062     SG_NOTIMPLEMENTED;
00063     return false;
00064 }
00065 
00066 float64_t CMultitaskLeastSquaresRegression::apply_one(int32_t i)
00067 {
00068     float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
00069     return dot + m_tasks_c[m_current_task];
00070 }
00071 
00072 int32_t CMultitaskLeastSquaresRegression::get_max_iter() const
00073 {
00074     return m_max_iter;
00075 }
00076 int32_t CMultitaskLeastSquaresRegression::get_regularization() const
00077 {
00078     return m_regularization;
00079 }
00080 int32_t CMultitaskLeastSquaresRegression::get_termination() const
00081 {
00082     return m_termination;
00083 }
00084 float64_t CMultitaskLeastSquaresRegression::get_tolerance() const
00085 {
00086     return m_tolerance;
00087 }
00088 float64_t CMultitaskLeastSquaresRegression::get_z() const
00089 {
00090     return m_z;
00091 }
00092 float64_t CMultitaskLeastSquaresRegression::get_q() const
00093 {
00094     return m_q;
00095 }
00096 
00097 void CMultitaskLeastSquaresRegression::set_max_iter(int32_t max_iter)
00098 {
00099     ASSERT(max_iter>=0);
00100     m_max_iter = max_iter;
00101 }
00102 void CMultitaskLeastSquaresRegression::set_regularization(int32_t regularization)
00103 {
00104     ASSERT(regularization==0 || regularization==1);
00105     m_regularization = regularization;
00106 }
00107 void CMultitaskLeastSquaresRegression::set_termination(int32_t termination)
00108 {
00109     ASSERT(termination>=0 && termination<=4);
00110     m_termination = termination;
00111 }
00112 void CMultitaskLeastSquaresRegression::set_tolerance(float64_t tolerance)
00113 {
00114     ASSERT(tolerance>0.0);
00115     m_tolerance = tolerance;
00116 }
00117 void CMultitaskLeastSquaresRegression::set_z(float64_t z)
00118 {
00119     m_z = z;
00120 }
00121 void CMultitaskLeastSquaresRegression::set_q(float64_t q)
00122 {
00123     m_q = q;
00124 }
00125 
00126 bool CMultitaskLeastSquaresRegression::train_machine(CFeatures* data)
00127 {
00128     if (data && (CDotFeatures*)data)
00129         set_features((CDotFeatures*)data);
00130 
00131     ASSERT(features);
00132     ASSERT(m_labels);
00133 
00134     SGVector<float64_t> y = ((CRegressionLabels*)m_labels)->get_labels();
00135     
00136     slep_options options = slep_options::default_options();
00137     options.n_tasks = m_task_relation->get_num_tasks();
00138     options.tasks_indices = m_task_relation->get_tasks_indices();
00139     options.q = m_q;
00140     options.regularization = m_regularization;
00141     options.termination = m_termination;
00142     options.tolerance = m_tolerance;
00143     options.max_iter = m_max_iter;
00144 
00145     ETaskRelationType relation_type = m_task_relation->get_relation_type();
00146     switch (relation_type)
00147     {
00148         case TASK_GROUP:
00149         {
00150             //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
00151             options.mode = MULTITASK_GROUP;
00152             options.loss = LEAST_SQUARES;
00153             m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
00154             m_tasks_c = SGVector<float64_t>(options.n_tasks);
00155             m_tasks_c.zero();
00156         }
00157         break;
00158         case TASK_TREE: 
00159         {
00160             CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00161             SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00162             options.ind_t = ind_t.vector;
00163             options.n_nodes = ind_t.vlen/3;
00164             options.mode = MULTITASK_TREE;
00165             options.loss = LEAST_SQUARES;
00166             m_tasks_w = slep_solver(features, y.vector, m_z, options).w;
00167             m_tasks_c = SGVector<float64_t>(options.n_tasks);
00168             m_tasks_c.zero();
00169         }
00170         break;
00171         default: 
00172             SG_ERROR("Not supported task relation type\n");
00173     }
00174 
00175     for (int32_t i=0; i<options.n_tasks; i++)
00176         options.tasks_indices[i].~SGVector<index_t>();
00177     SG_FREE(options.tasks_indices);
00178 
00179     return true;
00180 }
00181 
00182 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation