MultitaskTraceLogisticRegression.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskTraceLogisticRegression.h>
00011 #include <shogun/lib/malsar/malsar_low_rank.h>
00012 #include <shogun/lib/malsar/malsar_options.h>
00013 #include <shogun/lib/IndexBlockGroup.h>
00014 #include <shogun/lib/SGVector.h>
00015 
00016 namespace shogun
00017 {
00018 
00019 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression() :
00020     CMultitaskLogisticRegression(), m_rho(0.0)
00021 {
00022     init();
00023 }
00024 
00025 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression(
00026      float64_t rho, CDotFeatures* train_features, 
00027      CBinaryLabels* train_labels, CTaskGroup* task_group) :
00028     CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group)
00029 {
00030     set_rho(rho);
00031     init();
00032 }
00033 
00034 void CMultitaskTraceLogisticRegression::init()
00035 {
00036     SG_ADD(&m_rho,"rho","rho",MS_AVAILABLE);
00037 }
00038 
00039 void CMultitaskTraceLogisticRegression::set_rho(float64_t rho)
00040 {
00041     m_rho = rho;
00042 }
00043 
00044 float64_t CMultitaskTraceLogisticRegression::get_rho() const
00045 {
00046     return m_rho;
00047 }
00048 
00049 CMultitaskTraceLogisticRegression::~CMultitaskTraceLogisticRegression()
00050 {
00051 }
00052 
00053 bool CMultitaskTraceLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
00054 {
00055     SGVector<float64_t> y(m_labels->get_num_labels());
00056     for (int32_t i=0; i<y.vlen; i++)
00057         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00058     
00059     malsar_options options = malsar_options::default_options();
00060     options.termination = m_termination;
00061     options.tolerance = m_tolerance;
00062     options.max_iter = m_max_iter;
00063     options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00064     options.tasks_indices = tasks;
00065 
00066 #ifdef HAVE_EIGEN3
00067     malsar_result_t model = malsar_low_rank(
00068         features, y.vector, m_rho, options);
00069 
00070     m_tasks_w = model.w;
00071     m_tasks_c = model.c;
00072 #else
00073     SG_WARNING("Please install Eigen3 to use MultitaskTraceLogisticRegression\n");
00074     m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 
00075     m_tasks_c = SGVector<float64_t>(options.n_tasks); 
00076 #endif
00077     return true;
00078 }
00079 
00080 bool CMultitaskTraceLogisticRegression::train_machine(CFeatures* data)
00081 {
00082     if (data && (CDotFeatures*)data)
00083         set_features((CDotFeatures*)data);
00084 
00085     ASSERT(features);
00086     ASSERT(m_labels);
00087     ASSERT(m_task_relation);
00088 
00089     SGVector<float64_t> y(m_labels->get_num_labels());
00090     for (int32_t i=0; i<y.vlen; i++)
00091         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00092     
00093     malsar_options options = malsar_options::default_options();
00094     options.termination = m_termination;
00095     options.tolerance = m_tolerance;
00096     options.max_iter = m_max_iter;
00097     options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00098     options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
00099 
00100 #ifdef HAVE_EIGEN3
00101     malsar_result_t model = malsar_low_rank(
00102         features, y.vector, m_rho, options);
00103 
00104     m_tasks_w = model.w;
00105     m_tasks_c = model.c;
00106 #else
00107     SG_WARNING("Please install Eigen3 to use MultitaskTraceLogisticRegression\n");
00108     m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 
00109     m_tasks_c = SGVector<float64_t>(options.n_tasks); 
00110 #endif
00111 
00112     for (int32_t i=0; i<options.n_tasks; i++)
00113         options.tasks_indices[i].~SGVector<index_t>();
00114     SG_FREE(options.tasks_indices);
00115 
00116     return true;
00117 }
00118 
00119 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation