SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MultitaskCompositeMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
11 
12 #include <map>
13 #include <vector>
14 
15 using namespace std;
16 
17 namespace shogun
18 {
19 
20 CMultitaskCompositeMachine::CMultitaskCompositeMachine() :
21  CMachine(), m_machine(NULL), m_features(NULL), m_current_task(0),
22  m_task_group(NULL)
23 {
24  register_parameters();
25 }
26 
28  CMachine* machine, CFeatures* train_features,
29  CLabels* train_labels, CTaskGroup* task_group) :
30  CMachine(), m_machine(NULL), m_features(NULL),
31  m_current_task(0), m_task_group(NULL)
32 {
	// Constructor that takes everything needed for training. It stores the
	// wrapped machine, training features, training labels and task group
	// through their setters, then registers the serializable parameters.
	// NOTE(review): the first signature line (Doxygen line 27) is missing
	// from this listing.
	// NOTE(review): m_task_machines is not initialized in the initializer
	// list above, although register_parameters() registers it. Consider
	// setting it to NULL here.
33  set_machine(machine);
34  set_features(train_features);
35  set_labels(train_labels);
36  set_task_group(task_group);
37  register_parameters();
38 }
39 
41 {
	// NOTE(review): this listing is missing the signature (Doxygen line 40)
	// and the body (Doxygen lines 42-45). Nothing about this block's
	// behavior can be confirmed from here. Judging by its position after the
	// constructors it is presumably the destructor — confirm.
46 }
47 
48 void CMultitaskCompositeMachine::register_parameters()
49 {
50  SG_ADD((CSGObject**)&m_machine, "machine", "machine", MS_AVAILABLE);
51  SG_ADD((CSGObject**)&m_features, "features", "features", MS_NOT_AVAILABLE);
52  SG_ADD((CSGObject**)&m_task_machines, "task_machines", "task machines", MS_NOT_AVAILABLE);
53  SG_ADD((CSGObject**)&m_task_group, "task_group", "task group", MS_NOT_AVAILABLE);
54 }
55 
57 {
	// Returns the stored current task index, m_current_task.
	// NOTE(review): the signature (Doxygen line 56) is missing from this
	// listing.
58  return m_current_task;
59 }
60 
62 {
	// Stores `task` as the current task index. The value is not
	// range-checked against the number of tasks.
	// NOTE(review): the signature (Doxygen line 61) is missing from this
	// listing.
63  m_current_task = task;
64 }
65 
67 {
	// Returns the m_task_group pointer.
	// NOTE(review): the signature (Doxygen line 66) and Doxygen line 68 are
	// missing from this listing. Line 68 may adjust the reference count, so
	// confirm whether callers receive a new reference before relying on
	// ownership semantics.
69  return m_task_group;
70 }
71 
73 {
	// Takes a reference on the incoming task group (SG_REF) and stores it in
	// m_task_group.
	// NOTE(review): the signature (Doxygen line 72) and Doxygen line 74 are
	// missing from this listing. If line 74 does not release the previously
	// held m_task_group, replacing the group leaks the old one.
75  SG_REF(task_group);
76  m_task_group = task_group;
77 }
78 
80 {
	// Returns false. No other visible statement runs.
	// NOTE(review): the signature (Doxygen line 79) and Doxygen line 81 are
	// missing from this listing, so the function's name and any side effect
	// on line 81 cannot be confirmed from here.
82  return false;
83 }
84 
86 {
	// Prepares the per-task index sets used by locked training/prediction:
	//  1. locks the data of the wrapped m_machine if it is not locked yet;
	//  2. copies each task's index vector tasks_indices[i] into a std::set
	//     appended to m_tasks_indices. The sets are later queried with
	//     count() to find which task an index belongs to;
	//  3. destroys every SGVector in tasks_indices and frees the array.
	// NOTE(review): the signature (Doxygen line 85) and Doxygen lines 87-88
	// and 93 are missing from this listing. `labels`, `features` and
	// `tasks_indices` are not declared in any visible line and are presumably
	// parameters or locals defined there. `tasks_indices` is assumed to hold
	// n_tasks elements — confirm.
89  if (!m_machine->is_data_locked())
90  m_machine->data_lock(labels,features);
91 
92  int n_tasks = m_task_group->get_num_tasks();
94 
	// Rebuild the lookup sets from scratch on every call.
95  m_tasks_indices.clear();
96  for (int32_t i=0; i<n_tasks; i++)
97  {
98  set<index_t> indices_set;
99  SGVector<index_t> task_indices = tasks_indices[i];
100  for (int32_t j=0; j<task_indices.vlen; j++)
101  indices_set.insert(task_indices[j]);
102 
103  m_tasks_indices.push_back(indices_set);
104  }
105 
	// The array's elements are destroyed explicitly before its raw storage is
	// released with SG_FREE. This suggests the array was not allocated with
	// new[] — confirm against the task group's allocation.
106  for (int32_t i=0; i<n_tasks; i++)
107  tasks_indices[i].~SGVector<index_t>();
108  SG_FREE(tasks_indices);
109 }
110 
112 {
	// Trains the wrapped machine once per task, each time on the subset of
	// `indices` that belongs to that task:
	//  - each index goes to the first task j whose set m_tasks_indices[j]
	//    contains it. Indices that belong to no task are dropped;
	//  - m_machine->train_locked() is then called with each task's subset,
	//    including tasks whose subset is empty.
	// Always returns true.
	// NOTE(review): the signature (Doxygen line 111) and Doxygen lines 130
	// and 138 are missing from this listing. The return value of
	// m_machine->train_locked() is ignored, so a failed per-task training is
	// never reported to the caller.
113  int n_tasks = m_task_group->get_num_tasks();
	// The per-task lookup sets must already have been built for every task.
114  ASSERT((int)m_tasks_indices.size()==n_tasks);
	// One (initially empty) index list per task.
115  vector< vector<index_t> > cutted_task_indices;
116  for (int32_t i=0; i<n_tasks; i++)
117  cutted_task_indices.push_back(vector<index_t>());
118  for (int32_t i=0; i<indices.vlen; i++)
119  {
120  for (int32_t j=0; j<n_tasks; j++)
121  {
122  if (m_tasks_indices[j].count(indices[i]))
123  {
124  cutted_task_indices[j].push_back(indices[i]);
125  break;
126  }
127  }
128  }
129  //SG_UNREF(m_task_machines);
131  for (int32_t i=0; i<n_tasks; i++)
132  {
	// Copy this task's subset into an SGVector for the wrapped machine.
133  SGVector<index_t> task_indices(cutted_task_indices[i].size());
134  for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++)
135  task_indices[j] = cutted_task_indices[i][j];
136 
137  m_machine->train_locked(task_indices);
139  }
140  return true;
141 }
142 
144 {
	// Forwards apply_one(i) to the machine `m` and returns its result,
	// releasing the reference on `m` (SG_UNREF) before returning.
	// NOTE(review): the signature (Doxygen line 143) and Doxygen line 145,
	// which obtains `m`, are missing from this listing. Presumably line 145
	// fetches a referenced machine for m_current_task — confirm, because the
	// SG_UNREF below is only balanced if that line took a reference.
146  float64_t result = m->apply_one(i);
147  SG_UNREF(m);
148  return result;
149 }
150 
152 {
	// Predicts every index in `indices`. For each index, it finds the first
	// task j whose set m_tasks_indices[j] contains it, selects that task with
	// set_current_task(j) and stores apply_one(index) as the score. The
	// scores are returned wrapped in a newly allocated CBinaryLabels.
	// Indices that belong to no task keep the score 0, because the result
	// vector is zeroed first.
	// NOTE(review): the signature (Doxygen line 151) is missing from this
	// listing. Unlike the training code, there is no ASSERT that
	// m_tasks_indices holds n_tasks entries before it is indexed with j.
153  int n_tasks = m_task_group->get_num_tasks();
154  SGVector<float64_t> result(indices.vlen);
155  result.zero();
156  for (int32_t i=0; i<indices.vlen; i++)
157  {
158  for (int32_t j=0; j<n_tasks; j++)
159  {
160  if (m_tasks_indices[j].count(indices[i]))
161  {
162  set_current_task(j);
163  result[i] = apply_one(indices[i]);
164  break;
165  }
166  }
167  }
168  return new CBinaryLabels(result);
169 }
170 
171 }

SHOGUN Machine Learning Toolbox - Documentation