MultitaskCompositeMachine.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskCompositeMachine.h>
00011 
00012 #include <map>
00013 #include <vector>
00014 
00015 using namespace std;
00016 
00017 namespace shogun
00018 {
00019 
00020 CMultitaskCompositeMachine::CMultitaskCompositeMachine() :
00021     CMachine(), m_machine(NULL), m_features(NULL), m_current_task(0), 
00022     m_task_group(NULL)
00023 {
00024     register_parameters();
00025 }
00026 
00027 CMultitaskCompositeMachine::CMultitaskCompositeMachine(
00028      CMachine* machine, CFeatures* train_features, 
00029      CLabels* train_labels, CTaskGroup* task_group) :
00030     CMachine(), m_machine(NULL), m_features(NULL), 
00031     m_current_task(0), m_task_group(NULL)
00032 {
00033     set_machine(machine);
00034     set_features(train_features);
00035     set_labels(train_labels);
00036     set_task_group(task_group);
00037     register_parameters();
00038 }
00039 
00040 CMultitaskCompositeMachine::~CMultitaskCompositeMachine()
00041 {
00042     SG_UNREF(m_machine);
00043     SG_UNREF(m_features);
00044     SG_UNREF(m_task_machines);
00045     SG_UNREF(m_task_group);
00046 }
00047 
00048 void CMultitaskCompositeMachine::register_parameters()
00049 {
00050     SG_ADD((CSGObject**)&m_machine, "machine", "machine", MS_AVAILABLE);
00051     SG_ADD((CSGObject**)&m_features, "features", "features", MS_NOT_AVAILABLE);
00052     SG_ADD((CSGObject**)&m_task_machines, "task_machines", "task machines", MS_NOT_AVAILABLE);
00053     SG_ADD((CSGObject**)&m_task_group, "task_group", "task group", MS_NOT_AVAILABLE);
00054 }
00055 
00056 int32_t CMultitaskCompositeMachine::get_current_task() const
00057 {
00058     return m_current_task;
00059 }
00060 
00061 void CMultitaskCompositeMachine::set_current_task(int32_t task)
00062 {
00063     m_current_task = task;
00064 }
00065 
00066 CTaskGroup* CMultitaskCompositeMachine::get_task_group() const
00067 {
00068     SG_REF(m_task_group);
00069     return m_task_group;
00070 }
00071 
00072 void CMultitaskCompositeMachine::set_task_group(CTaskGroup* task_group)
00073 {
00074     SG_UNREF(m_task_group);
00075     SG_REF(task_group);
00076     m_task_group = task_group;
00077 }
00078 
00079 bool CMultitaskCompositeMachine::train_machine(CFeatures* data)
00080 {
00081     SG_NOTIMPLEMENTED;
00082     return false;
00083 }
00084 
00085 void CMultitaskCompositeMachine::post_lock(CLabels* labels, CFeatures* features)
00086 {
00087     ASSERT(m_task_group);
00088     set_features(m_features);
00089     if (!m_machine->is_data_locked())
00090         m_machine->data_lock(labels,features);
00091 
00092     int n_tasks = m_task_group->get_num_tasks();
00093     SGVector<index_t>* tasks_indices = m_task_group->get_tasks_indices();
00094 
00095     m_tasks_indices.clear();
00096     for (int32_t i=0; i<n_tasks; i++)
00097     {
00098         set<index_t> indices_set;
00099         SGVector<index_t> task_indices = tasks_indices[i];
00100         for (int32_t j=0; j<task_indices.vlen; j++)
00101             indices_set.insert(task_indices[j]);
00102 
00103         m_tasks_indices.push_back(indices_set);
00104     }
00105 
00106     for (int32_t i=0; i<n_tasks; i++)
00107         tasks_indices[i].~SGVector<index_t>();
00108     SG_FREE(tasks_indices);
00109 }
00110 
00111 bool CMultitaskCompositeMachine::train_locked(SGVector<index_t> indices)
00112 {
00113     int n_tasks = m_task_group->get_num_tasks();
00114     ASSERT((int)m_tasks_indices.size()==n_tasks);
00115     vector< vector<index_t> > cutted_task_indices;
00116     for (int32_t i=0; i<n_tasks; i++)
00117         cutted_task_indices.push_back(vector<index_t>());
00118     for (int32_t i=0; i<indices.vlen; i++)
00119     {
00120         for (int32_t j=0; j<n_tasks; j++)
00121         {
00122             if (m_tasks_indices[j].count(indices[i]))
00123             {
00124                 cutted_task_indices[j].push_back(indices[i]);
00125                 break;
00126             }
00127         }
00128     }
00129     //SG_UNREF(m_task_machines);
00130     m_task_machines = new CDynamicObjectArray();
00131     for (int32_t i=0; i<n_tasks; i++)
00132     {
00133         SGVector<index_t> task_indices(cutted_task_indices[i].size());
00134         for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++)
00135             task_indices[j] = cutted_task_indices[i][j];
00136 
00137         m_machine->train_locked(task_indices);
00138         m_task_machines->push_back(m_machine->clone());
00139     }
00140     return true;
00141 }
00142 
00143 float64_t CMultitaskCompositeMachine::apply_one(int32_t i)
00144 {
00145     CMachine* m = (CMachine*)(m_task_machines->get_element(m_current_task));
00146     float64_t result = m->apply_one(i);
00147     SG_UNREF(m);
00148     return result;
00149 }
00150 
00151 CBinaryLabels* CMultitaskCompositeMachine::apply_locked_binary(SGVector<index_t> indices)
00152 {
00153     int n_tasks = m_task_group->get_num_tasks();
00154     SGVector<float64_t> result(indices.vlen);
00155     result.zero();
00156     for (int32_t i=0; i<indices.vlen; i++)
00157     {
00158         for (int32_t j=0; j<n_tasks; j++)
00159         {
00160             if (m_tasks_indices[j].count(indices[i]))
00161             {
00162                 set_current_task(j);
00163                 result[i] = apply_one(indices[i]);
00164                 break;
00165             }
00166         }
00167     }
00168     return new CBinaryLabels(result);
00169 }
00170 
00171 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation