SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
MultitaskTraceLogisticRegression.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
10 
12 #ifdef USE_GPL_SHOGUN
16 #include <shogun/lib/SGVector.h>
18 
19 namespace shogun
20 {
21 
22 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression() :
23  CMultitaskLogisticRegression(), m_rho(0.0)
24 {
25  init();
26 }
27 
28 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression(
29  float64_t rho, CDotFeatures* train_features,
30  CBinaryLabels* train_labels, CTaskGroup* task_group) :
31  CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group)
32 {
33  set_rho(rho);
34  init();
35 }
36 
/** Registers this class's own parameters with the SGObject machinery.
 *
 * Called from every constructor. Only m_rho is registered, under the name
 * "rho" with description "rho". The MS_AVAILABLE flag presumably exposes it
 * to model selection (see the SG_ADD macro in SGObject.h).
 */
37 void CMultitaskTraceLogisticRegression::init()
38 {
39  SG_ADD(&m_rho,"rho","rho",MS_AVAILABLE);
40 }
41 
42 void CMultitaskTraceLogisticRegression::set_rho(float64_t rho)
43 {
44  m_rho = rho;
45 }
46 
47 float64_t CMultitaskTraceLogisticRegression::get_rho() const
48 {
49  return m_rho;
50 }
51 
52 CMultitaskTraceLogisticRegression::~CMultitaskTraceLogisticRegression()
53 {
54 }
55 
56 bool CMultitaskTraceLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
57 {
58  SGVector<float64_t> y(m_labels->get_num_labels());
59  for (int32_t i=0; i<y.vlen; i++)
60  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
61 
62  malsar_options options = malsar_options::default_options();
63  options.termination = m_termination;
64  options.tolerance = m_tolerance;
65  options.max_iter = m_max_iter;
66  options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
67  options.tasks_indices = tasks;
68 
69  malsar_result_t model = malsar_low_rank(
70  features, y.vector, m_rho, options);
71 
72  m_tasks_w = model.w;
73  m_tasks_c = model.c;
74  return true;
75 }
76 
77 bool CMultitaskTraceLogisticRegression::train_machine(CFeatures* data)
78 {
79  if (data && (CDotFeatures*)data)
80  set_features((CDotFeatures*)data);
81 
82  ASSERT(features)
83  ASSERT(m_labels)
84  ASSERT(m_task_relation)
85 
86  SGVector<float64_t> y(m_labels->get_num_labels());
87  for (int32_t i=0; i<y.vlen; i++)
88  y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
89 
90  malsar_options options = malsar_options::default_options();
91  options.termination = m_termination;
92  options.tolerance = m_tolerance;
93  options.max_iter = m_max_iter;
94  options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
95  options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
96 
97  malsar_result_t model = malsar_low_rank(
98  features, y.vector, m_rho, options);
99 
100  m_tasks_w = model.w;
101  m_tasks_c = model.c;
102 
103  SG_FREE(options.tasks_indices);
104 
105  return true;
106 }
107 
108 }
109 
110 #endif //USE_GPL_SHOGUN
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_ADD(...)
Definition: SGObject.h:81

SHOGUN Machine Learning Toolbox - Documentation