MultitaskLinearMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #ifndef  MULTITASKMACHINE_H_
00011 #define  MULTITASKMACHINE_H_
00012 
00013 #include <shogun/lib/config.h>
00014 #include <shogun/machine/LinearMachine.h>
00015 #include <shogun/transfer/multitask/TaskRelation.h>
00016 #include <shogun/transfer/multitask/TaskGroup.h>
00017 #include <shogun/transfer/multitask/TaskTree.h>
00018 #include <shogun/transfer/multitask/Task.h>
00019 
00020 #include <vector>
00021 #include <set>
00022 
00023 using namespace std;
00024 
00025 namespace shogun
00026 {
00030 class CMultitaskLinearMachine : public CLinearMachine
00031 {
00032 
00033     public:
00035         CMultitaskLinearMachine();
00036 
00043         CMultitaskLinearMachine(
00044              CDotFeatures* training_data, 
00045              CLabels* training_labels, CTaskRelation* task_relation);
00046 
00048         virtual ~CMultitaskLinearMachine();
00049 
00051         virtual const char* get_name() const 
00052         {
00053             return "MultitaskLinearMachine";
00054         }
00055 
00059         int32_t get_current_task() const;
00060 
00064         void set_current_task(int32_t task);
00065         
00070         virtual SGVector<float64_t> get_w() const;
00071 
00076         virtual void set_w(const SGVector<float64_t> src_w);
00077 
00082         virtual void set_bias(float64_t b);
00083 
00088         virtual float64_t get_bias();
00089 
00093         CTaskRelation* get_task_relation() const;
00094 
00098         void set_task_relation(CTaskRelation* task_relation);
00099 
00101         virtual bool supports_locking() const { return true; }
00102 
00104         virtual void post_lock(CLabels* labels, CFeatures* features_);
00105 
00107         virtual bool train_locked(SGVector<index_t> indices);
00108 
00110         virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices);
00111 
00113         virtual float64_t apply_one(int32_t i);
00114 
00115     protected:
00116 
00118         virtual SGVector<float64_t> apply_get_outputs(CFeatures* data=NULL);
00119 
00121         virtual bool train_machine(CFeatures* data=NULL);
00122 
00124         virtual bool train_locked_implementation(SGVector<index_t>* tasks);
00125 
00127         SGVector<index_t>* get_subset_tasks_indices();
00128 
00129     private:
00130 
00132         void register_parameters();
00133 
00134     protected:
00135 
00137         int32_t m_current_task;
00138 
00140         CTaskRelation* m_task_relation;
00141 
00143         SGMatrix<float64_t> m_tasks_w;
00144         
00146         SGVector<float64_t> m_tasks_c;
00147 
00149         vector< set<index_t> > m_tasks_indices;
00150 
00151 };
00152 }
00153 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation