// Default constructor: delegates to CMultitaskLogisticRegression() and
// zero-initialises m_rho1 and m_rho2 (the two rho values later passed to
// malsar_clustered() during training).
// NOTE(review): m_num_clusters does not appear in this initialiser list;
// confirm the constructor body (not shown here) initialises it, since both
// training paths read it into options.n_clusters.
22 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression() :
23 CMultitaskLogisticRegression(), m_rho1(0.0), m_rho2(0.0)
// Constructs the learner from training data, binary labels, a task grouping
// and the number of clusters.
// Delegates to CMultitaskLogisticRegression(0.0, train_features, train_labels,
// task_group) with task_group up-cast to CTaskRelation*, then records
// n_clusters through set_num_clusters().
// NOTE(review): train_features is used in the base-class call but its
// parameter declaration is not visible here — confirm its type.
// NOTE(review): m_rho1/m_rho2 are not set in the visible initialiser list;
// confirm they are initialised before training reads them.
27 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression(
29 CBinaryLabels* train_labels, CTaskGroup* task_group, int32_t n_clusters) :
30 CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group)
34 set_num_clusters(n_clusters);
// Getter for the rho1 value that training passes to malsar_clustered().
// NOTE(review): declared as returning int32_t, while set_rho1() accepts a
// float64_t. If the (not shown) body returns m_rho1, any fractional part is
// silently truncated. Consider changing the return type to float64_t; the
// header declaration must change as well.
37 int32_t CMultitaskClusteredLogisticRegression::get_rho1()
const
// Getter for the rho2 value that training passes to malsar_clustered().
// NOTE(review): declared as returning int32_t, while set_rho2() accepts a
// float64_t. If the (not shown) body returns m_rho2, any fractional part is
// silently truncated. Consider changing the return type to float64_t; the
// header declaration must change as well.
42 int32_t CMultitaskClusteredLogisticRegression::get_rho2()
const
// Setter for the rho1 value (a float64_t) used by subsequent training calls.
// NOTE(review): body not visible here; presumably assigns m_rho1 — confirm.
47 void CMultitaskClusteredLogisticRegression::set_rho1(
float64_t rho1)
// Setter for the rho2 value (a float64_t) used by subsequent training calls.
// NOTE(review): body not visible here; presumably assigns m_rho2 — confirm.
52 void CMultitaskClusteredLogisticRegression::set_rho2(
float64_t rho2)
// Returns the configured number of clusters (m_num_clusters). Both training
// paths forward this value to the solver as options.n_clusters.
57 int32_t CMultitaskClusteredLogisticRegression::get_num_clusters()
const
59 return m_num_clusters;
// Sets the number of clusters (m_num_clusters) used by subsequent training.
// The visible code stores the value as-is, without range validation.
62 void CMultitaskClusteredLogisticRegression::set_num_clusters(int32_t num_clusters)
64 m_num_clusters = num_clusters;
// Destructor.
// NOTE(review): the destructor body is not visible in this view.
67 CMultitaskClusteredLogisticRegression::~CMultitaskClusteredLogisticRegression()
// Trains the clustered multitask model using the caller-supplied task index
// set `tasks`, which is passed to the solver as options.tasks_indices.
// Steps visible here:
//  1. Copy the binary labels from m_labels into a dense vector y.
//  2. Fill malsar_options from m_termination, m_tolerance, m_max_iter, the
//     task count of the CTaskGroup, `tasks` and m_num_clusters.
//  3. Call malsar_clustered(features, y, m_rho1, m_rho2, options).
//  4. A fallback warns "Clustered LR is unstable with C++11" and sets
//     m_tasks_w (feature dim x n_tasks) and m_tasks_c (n_tasks) to zero.
// NOTE(review): the code that selects between the solver result and the
// zero-filled fallback (and that stores `model`) is not visible here — confirm.
// NOTE(review): m_labels, m_task_relation and features are C-style cast to
// CBinaryLabels*, CTaskGroup* and CDotFeatures* without any type check.
71 bool CMultitaskClusteredLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
// Dense copy of the labels in the layout malsar_clustered() takes (y.vector).
73 SGVector<float64_t> y(m_labels->get_num_labels());
74 for (int32_t i=0; i<y.vlen; i++)
75 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
// Solver options: start from defaults, then override with this machine's settings.
77 malsar_options options = malsar_options::default_options();
78 options.termination = m_termination;
79 options.tolerance = m_tolerance;
80 options.max_iter = m_max_iter;
81 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
82 options.tasks_indices = tasks;
83 options.n_clusters = m_num_clusters;
// Run the clustered MALSAR solver.
86 malsar_result_t model = malsar_clustered(
87 features, y.vector, m_rho1, m_rho2, options);
// Fallback: warn and leave all task weights (m_tasks_w) and m_tasks_c at zero.
92 SG_WARNING(
"Clustered LR is unstable with C++11\n")
93 m_tasks_w = SGMatrix<
float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks);
94 m_tasks_w.set_const(0);
95 m_tasks_c = SGVector<
float64_t>(options.n_tasks);
96 m_tasks_c.set_const(0);
// Trains the clustered multitask model on the tasks described by the
// CTaskGroup in m_task_relation. If `data` is non-null, it is installed as
// the training features first.
// Differs from train_locked_implementation() in where the task indices come
// from: here they are taken from CTaskGroup::get_tasks_indices() and are
// freed with SG_FREE after training.
// NOTE(review): in `data && (CDotFeatures*)data`, the C-style cast is
// non-null whenever `data` is, so the condition only tests data != nullptr.
// A real type check (e.g. dynamic_cast or a feature-class query) is likely
// what was intended.
// NOTE(review): this body repeats most of train_locked_implementation();
// consider a shared helper that takes the task indices as a parameter.
101 bool CMultitaskClusteredLogisticRegression::train_machine(CFeatures* data)
103 if (data && (CDotFeatures*)data)
104 set_features((CDotFeatures*)data);
// Dense copy of the labels in the layout malsar_clustered() takes (y.vector).
110 SGVector<
float64_t> y(m_labels->get_num_labels());
111 for (int32_t i=0; i<y.vlen; i++)
112 y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
// Solver options: start from defaults, then override with this machine's settings.
114 malsar_options options = malsar_options::default_options();
115 options.termination = m_termination;
116 options.tolerance = m_tolerance;
117 options.max_iter = m_max_iter;
118 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
119 options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
120 options.n_clusters = m_num_clusters;
// Run the clustered MALSAR solver.
123 malsar_result_t model = malsar_clustered(
124 features, y.vector, m_rho1, m_rho2, options);
// Fallback: warn and leave all task weights (m_tasks_w) and m_tasks_c at zero.
// NOTE(review): the branch selecting this fallback over `model` is not visible here.
129 SG_WARNING(
"Clustered LR is unstable with C++11\n")
130 m_tasks_w = SGMatrix<
float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks);
131 m_tasks_w.set_const(0);
132 m_tasks_c = SGVector<
float64_t>(options.n_tasks);
133 m_tasks_c.set_const(0);
// Release the task index array obtained from get_tasks_indices() above.
136 SG_FREE(options.tasks_indices);
143 #endif //USE_GPL_SHOGUN
All classes and functions are contained in the shogun namespace.