00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #ifndef MULTITASKMACHINE_H_ 00011 #define MULTITASKMACHINE_H_ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/machine/LinearMachine.h> 00015 #include <shogun/transfer/multitask/TaskRelation.h> 00016 #include <shogun/transfer/multitask/TaskGroup.h> 00017 #include <shogun/transfer/multitask/TaskTree.h> 00018 #include <shogun/transfer/multitask/Task.h> 00019 00020 #include <vector> 00021 #include <set> 00022 00023 using namespace std; 00024 00025 namespace shogun 00026 { 00030 class CMultitaskLinearMachine : public CLinearMachine 00031 { 00032 00033 public: 00035 CMultitaskLinearMachine(); 00036 00043 CMultitaskLinearMachine( 00044 CDotFeatures* training_data, 00045 CLabels* training_labels, CTaskRelation* task_relation); 00046 00048 virtual ~CMultitaskLinearMachine(); 00049 00051 virtual const char* get_name() const 00052 { 00053 return "MultitaskLinearMachine"; 00054 } 00055 00059 int32_t get_current_task() const; 00060 00064 void set_current_task(int32_t task); 00065 00070 virtual SGVector<float64_t> get_w() const; 00071 00076 virtual void set_w(const SGVector<float64_t> src_w); 00077 00082 virtual void set_bias(float64_t b); 00083 00088 virtual float64_t get_bias(); 00089 00093 CTaskRelation* get_task_relation() const; 00094 00098 void set_task_relation(CTaskRelation* task_relation); 00099 00101 virtual bool supports_locking() const { return true; } 00102 00104 virtual void post_lock(CLabels* labels, CFeatures* features_); 00105 00107 virtual bool train_locked(SGVector<index_t> indices); 00108 00110 virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices); 00111 00113 virtual float64_t apply_one(int32_t i); 00114 00115 protected: 00116 00118 virtual SGVector<float64_t> apply_get_outputs(CFeatures* data=NULL); 00119 00121 virtual bool train_machine(CFeatures* data=NULL); 00122 00124 virtual bool train_locked_implementation(SGVector<index_t>* tasks); 00125 00127 SGVector<index_t>* get_subset_tasks_indices(); 00128 00129 private: 00130 00132 void register_parameters(); 00133 00134 protected: 00135 00137 int32_t m_current_task; 00138 00140 CTaskRelation* m_task_relation; 00141 00143 SGMatrix<float64_t> m_tasks_w; 00144 00146 SGVector<float64_t> m_tasks_c; 00147 00149 vector< set<index_t> > m_tasks_indices; 00150 00151 }; 00152 } 00153 #endif