00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #ifndef MULTITASKCOMPOSITEMACHINE_H_ 00011 #define MULTITASKCOMPOSITEMACHINE_H_ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/machine/Machine.h> 00015 #include <shogun/features/DotFeatures.h> 00016 #include <shogun/transfer/multitask/TaskRelation.h> 00017 #include <shogun/transfer/multitask/TaskGroup.h> 00018 #include <shogun/transfer/multitask/TaskTree.h> 00019 #include <shogun/transfer/multitask/Task.h> 00020 00021 #include <vector> 00022 #include <set> 00023 00024 using namespace std; 00025 00026 namespace shogun 00027 { 00032 class CMultitaskCompositeMachine : public CMachine 00033 { 00034 00035 public: 00036 MACHINE_PROBLEM_TYPE(PT_BINARY) 00037 00038 00039 CMultitaskCompositeMachine(); 00040 00048 CMultitaskCompositeMachine( 00049 CMachine* machine, CFeatures* training_data, 00050 CLabels* training_labels, CTaskGroup* task_group); 00051 00053 virtual ~CMultitaskCompositeMachine(); 00054 00056 virtual const char* get_name() const 00057 { 00058 return "MultitaskCompositeMachine"; 00059 } 00060 00064 int32_t get_current_task() const; 00065 00069 void set_current_task(int32_t task); 00070 00074 CTaskGroup* get_task_group() const; 00075 00079 void set_task_group(CTaskGroup* task_group); 00080 00082 virtual bool supports_locking() const { return true; } 00083 00085 virtual void post_lock(CLabels* labels, CFeatures* features); 00086 00088 virtual bool train_locked(SGVector<index_t> indices); 00089 00091 virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices); 00092 00097 virtual void set_features(CFeatures* features) 00098 { 00099 SG_REF(features); 00100 SG_UNREF(m_features); 00101 m_features = features; 00102 } 00103 00108 virtual CFeatures* get_features() const 00109 { 00110 SG_REF(m_features); 00111 return m_features; 00112 } 00113 00118 virtual void set_machine(CMachine* machine) 00119 { 00120 SG_REF(machine); 00121 SG_UNREF(m_machine); 00122 m_machine = machine; 00123 } 00124 00129 virtual CMachine* get_machine() const 00130 { 00131 SG_REF(m_machine); 00132 return m_machine; 00133 } 00134 00136 virtual float64_t apply_one(int32_t vec_idx); 00137 00138 protected: 00139 00141 virtual bool train_machine(CFeatures* data=NULL); 00142 00143 private: 00144 00146 void register_parameters(); 00147 00148 protected: 00149 00151 CMachine* m_machine; 00152 00154 CFeatures* m_features; 00155 00157 int32_t m_current_task; 00158 00160 CTaskGroup* m_task_group; 00161 00163 CDynamicObjectArray* m_task_machines; 00164 00166 vector< set<index_t> > m_tasks_indices; 00167 00168 }; 00169 } 00170 #endif