Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/transfer/multitask/MultitaskCompositeMachine.h>
00011
#include <map>
#include <set>
#include <vector>
00014
00015 using namespace std;
00016
00017 namespace shogun
00018 {
00019
00020 CMultitaskCompositeMachine::CMultitaskCompositeMachine() :
00021 CMachine(), m_machine(NULL), m_features(NULL), m_current_task(0),
00022 m_task_group(NULL)
00023 {
00024 register_parameters();
00025 }
00026
00027 CMultitaskCompositeMachine::CMultitaskCompositeMachine(
00028 CMachine* machine, CFeatures* train_features,
00029 CLabels* train_labels, CTaskGroup* task_group) :
00030 CMachine(), m_machine(NULL), m_features(NULL),
00031 m_current_task(0), m_task_group(NULL)
00032 {
00033 set_machine(machine);
00034 set_features(train_features);
00035 set_labels(train_labels);
00036 set_task_group(task_group);
00037 register_parameters();
00038 }
00039
00040 CMultitaskCompositeMachine::~CMultitaskCompositeMachine()
00041 {
00042 SG_UNREF(m_machine);
00043 SG_UNREF(m_features);
00044 SG_UNREF(m_task_machines);
00045 SG_UNREF(m_task_group);
00046 }
00047
00048 void CMultitaskCompositeMachine::register_parameters()
00049 {
00050 SG_ADD((CSGObject**)&m_machine, "machine", "machine", MS_AVAILABLE);
00051 SG_ADD((CSGObject**)&m_features, "features", "features", MS_NOT_AVAILABLE);
00052 SG_ADD((CSGObject**)&m_task_machines, "task_machines", "task machines", MS_NOT_AVAILABLE);
00053 SG_ADD((CSGObject**)&m_task_group, "task_group", "task group", MS_NOT_AVAILABLE);
00054 }
00055
00056 int32_t CMultitaskCompositeMachine::get_current_task() const
00057 {
00058 return m_current_task;
00059 }
00060
00061 void CMultitaskCompositeMachine::set_current_task(int32_t task)
00062 {
00063 m_current_task = task;
00064 }
00065
00066 CTaskGroup* CMultitaskCompositeMachine::get_task_group() const
00067 {
00068 SG_REF(m_task_group);
00069 return m_task_group;
00070 }
00071
00072 void CMultitaskCompositeMachine::set_task_group(CTaskGroup* task_group)
00073 {
00074 SG_UNREF(m_task_group);
00075 SG_REF(task_group);
00076 m_task_group = task_group;
00077 }
00078
00079 bool CMultitaskCompositeMachine::train_machine(CFeatures* data)
00080 {
00081 SG_NOTIMPLEMENTED;
00082 return false;
00083 }
00084
00085 void CMultitaskCompositeMachine::post_lock(CLabels* labels, CFeatures* features)
00086 {
00087 ASSERT(m_task_group);
00088 set_features(m_features);
00089 if (!m_machine->is_data_locked())
00090 m_machine->data_lock(labels,features);
00091
00092 int n_tasks = m_task_group->get_num_tasks();
00093 SGVector<index_t>* tasks_indices = m_task_group->get_tasks_indices();
00094
00095 m_tasks_indices.clear();
00096 for (int32_t i=0; i<n_tasks; i++)
00097 {
00098 set<index_t> indices_set;
00099 SGVector<index_t> task_indices = tasks_indices[i];
00100 for (int32_t j=0; j<task_indices.vlen; j++)
00101 indices_set.insert(task_indices[j]);
00102
00103 m_tasks_indices.push_back(indices_set);
00104 }
00105
00106 for (int32_t i=0; i<n_tasks; i++)
00107 tasks_indices[i].~SGVector<index_t>();
00108 SG_FREE(tasks_indices);
00109 }
00110
00111 bool CMultitaskCompositeMachine::train_locked(SGVector<index_t> indices)
00112 {
00113 int n_tasks = m_task_group->get_num_tasks();
00114 ASSERT((int)m_tasks_indices.size()==n_tasks);
00115 vector< vector<index_t> > cutted_task_indices;
00116 for (int32_t i=0; i<n_tasks; i++)
00117 cutted_task_indices.push_back(vector<index_t>());
00118 for (int32_t i=0; i<indices.vlen; i++)
00119 {
00120 for (int32_t j=0; j<n_tasks; j++)
00121 {
00122 if (m_tasks_indices[j].count(indices[i]))
00123 {
00124 cutted_task_indices[j].push_back(indices[i]);
00125 break;
00126 }
00127 }
00128 }
00129
00130 m_task_machines = new CDynamicObjectArray();
00131 for (int32_t i=0; i<n_tasks; i++)
00132 {
00133 SGVector<index_t> task_indices(cutted_task_indices[i].size());
00134 for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++)
00135 task_indices[j] = cutted_task_indices[i][j];
00136
00137 m_machine->train_locked(task_indices);
00138 m_task_machines->push_back(m_machine->clone());
00139 }
00140 return true;
00141 }
00142
00143 float64_t CMultitaskCompositeMachine::apply_one(int32_t i)
00144 {
00145 CMachine* m = (CMachine*)(m_task_machines->get_element(m_current_task));
00146 float64_t result = m->apply_one(i);
00147 SG_UNREF(m);
00148 return result;
00149 }
00150
00151 CBinaryLabels* CMultitaskCompositeMachine::apply_locked_binary(SGVector<index_t> indices)
00152 {
00153 int n_tasks = m_task_group->get_num_tasks();
00154 SGVector<float64_t> result(indices.vlen);
00155 result.zero();
00156 for (int32_t i=0; i<indices.vlen; i++)
00157 {
00158 for (int32_t j=0; j<n_tasks; j++)
00159 {
00160 if (m_tasks_indices[j].count(indices[i]))
00161 {
00162 set_current_task(j);
00163 result[i] = apply_one(indices[i]);
00164 break;
00165 }
00166 }
00167 }
00168 return new CBinaryLabels(result);
00169 }
00170
00171 }