Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/transfer/multitask/MultitaskL12LogisticRegression.h>
00011 #include <shogun/lib/malsar/malsar_joint_feature_learning.h>
00012 #include <shogun/lib/malsar/malsar_options.h>
00013 #include <shogun/lib/SGVector.h>
00014
00015 namespace shogun
00016 {
00017
00018 CMultitaskL12LogisticRegression::CMultitaskL12LogisticRegression() :
00019 CMultitaskLogisticRegression(), m_rho1(0.0), m_rho2(0.0)
00020 {
00021 init();
00022 }
00023
00024 CMultitaskL12LogisticRegression::CMultitaskL12LogisticRegression(
00025 float64_t rho1, float64_t rho2, CDotFeatures* train_features,
00026 CBinaryLabels* train_labels, CTaskGroup* task_group) :
00027 CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group)
00028 {
00029 set_rho1(rho1);
00030 set_rho2(rho2);
00031 init();
00032 }
00033
00034 void CMultitaskL12LogisticRegression::init()
00035 {
00036 SG_ADD(&m_rho1,"rho1","rho L1/L2 regularization parameter",MS_AVAILABLE);
00037 SG_ADD(&m_rho2,"rho2","rho L2 regularization parameter",MS_AVAILABLE);
00038 }
00039
00040 void CMultitaskL12LogisticRegression::set_rho1(float64_t rho1)
00041 {
00042 m_rho1 = rho1;
00043 }
00044
00045 void CMultitaskL12LogisticRegression::set_rho2(float64_t rho2)
00046 {
00047 m_rho2 = rho2;
00048 }
00049
00050 float64_t CMultitaskL12LogisticRegression::get_rho1() const
00051 {
00052 return m_rho1;
00053 }
00054
00055 float64_t CMultitaskL12LogisticRegression::get_rho2() const
00056 {
00057 return m_rho2;
00058 }
00059
00060 CMultitaskL12LogisticRegression::~CMultitaskL12LogisticRegression()
00061 {
00062 }
00063
00064 bool CMultitaskL12LogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
00065 {
00066 SGVector<float64_t> y(m_labels->get_num_labels());
00067 for (int32_t i=0; i<y.vlen; i++)
00068 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00069
00070 malsar_options options = malsar_options::default_options();
00071 options.termination = m_termination;
00072 options.tolerance = m_tolerance;
00073 options.max_iter = m_max_iter;
00074 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00075 options.tasks_indices = tasks;
00076 #ifdef HAVE_EIGEN3
00077 malsar_result_t model = malsar_joint_feature_learning(
00078 features, y.vector, m_rho1, m_rho2, options);
00079
00080 m_tasks_w = model.w;
00081 m_tasks_c = model.c;
00082 #else
00083 SG_WARNING("Please install Eigen3 to use MultitaskL12LogisticRegression\n");
00084 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks);
00085 m_tasks_c = SGVector<float64_t>(options.n_tasks);
00086 #endif
00087
00088 return true;
00089 }
00090
00091 bool CMultitaskL12LogisticRegression::train_machine(CFeatures* data)
00092 {
00093 if (data && (CDotFeatures*)data)
00094 set_features((CDotFeatures*)data);
00095
00096 ASSERT(features);
00097 ASSERT(m_labels);
00098 ASSERT(m_task_relation);
00099
00100 SGVector<float64_t> y(m_labels->get_num_labels());
00101 for (int32_t i=0; i<y.vlen; i++)
00102 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00103
00104 malsar_options options = malsar_options::default_options();
00105 options.termination = m_termination;
00106 options.tolerance = m_tolerance;
00107 options.max_iter = m_max_iter;
00108 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00109 options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
00110
00111 #ifdef HAVE_EIGEN3
00112 malsar_result_t model = malsar_joint_feature_learning(
00113 features, y.vector, m_rho1, m_rho2, options);
00114
00115 m_tasks_w = model.w;
00116 m_tasks_c = model.c;
00117 #else
00118 SG_WARNING("Please install Eigen3 to use MultitaskL12LogisticRegression\n");
00119 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks);
00120 m_tasks_c = SGVector<float64_t>(options.n_tasks);
00121 #endif
00122
00123 for (int32_t i=0; i<options.n_tasks; i++)
00124 options.tasks_indices[i].~SGVector<index_t>();
00125 SG_FREE(options.tasks_indices);
00126
00127 return true;
00128 }
00129
00130 }