SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
MultitaskLinearMachine.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Copyright (C) 2012 Sergey Lisitsyn
8  */
9 
10 #ifndef MULTITASKMACHINE_H_
11 #define MULTITASKMACHINE_H_
12 
13 #include <shogun/lib/config.h>
14 #ifdef USE_GPL_SHOGUN
20 
21 #include <vector>
22 #include <set>
23 
24 namespace shogun
25 {
/** @brief Linear machine for multitask learning.
 *
 * Derives from CLinearMachine and adds per-task state:
 * - a task relation object (m_task_relation),
 * - a matrix of task weights (m_tasks_w),
 * - a vector of per-task values (m_tasks_c),
 * - one index set per task (m_tasks_indices).
 *
 * It advertises support for data locking (supports_locking() returns true).
 * NOTE(review): the roles of the per-task members below are inferred from
 * their names and types; confirm against MultitaskLinearMachine.cpp.
 */
29 class CMultitaskLinearMachine : public CLinearMachine
30 {
31 
32  public:
 /** default constructor */
34  CMultitaskLinearMachine();
35 
 /** constructor
  *
  * @param training_data features used for training
  * @param training_labels labels used for training
  * @param task_relation task relation object describing the tasks
  *        (presumably how vectors are grouped into tasks — confirm)
  */
42  CMultitaskLinearMachine(
43  CDotFeatures* training_data,
44  CLabels* training_labels, CTaskRelation* task_relation);
45 
 /** destructor */
47  virtual ~CMultitaskLinearMachine();
48 
 /** @return object name, always "MultitaskLinearMachine" */
50  virtual const char* get_name() const
51  {
52  return "MultitaskLinearMachine";
53  }
54 
 /** get the currently selected task
  *
  * @return task index (presumably the value of m_current_task)
  */
58  int32_t get_current_task() const;
59 
 /** select the current task
  *
  * @param task task index (presumably stored in m_current_task)
  */
63  void set_current_task(int32_t task);
64 
 /** get the weight vector
  *
  * NOTE(review): presumably returns the weights of the current task taken
  * from m_tasks_w rather than a single shared vector — confirm in the .cpp.
  * @return weight vector
  */
69  virtual SGVector<float64_t> get_w() const;
70 
 /** set the weight vector
  *
  * NOTE(review): presumably writes into the current task's slot of
  * m_tasks_w — confirm in the .cpp.
  * @param src_w weight vector to store
  */
75  virtual void set_w(const SGVector<float64_t> src_w);
76 
 /** set the bias
  *
  * NOTE(review): presumably the bias of the current task, kept in
  * m_tasks_c — confirm in the .cpp.
  * @param b bias value
  */
81  virtual void set_bias(float64_t b);
82 
 /** get the bias
  *
  * Declared non-const (unlike get_w()); presumably to match the base-class
  * signature it overrides — confirm before changing.
  * @return bias value (presumably of the current task)
  */
87  virtual float64_t get_bias();
88 
 /** get the task relation
  *
  * @return task relation (presumably m_task_relation); reference-count
  *         handling is not visible in this header
  */
92  CTaskRelation* get_task_relation() const;
93 
 /** set the task relation
  *
  * @param task_relation task relation to use
  */
97  void set_task_relation(CTaskRelation* task_relation);
98 
 /** @return true: this machine supports data locking */
100  virtual bool supports_locking() const { return true; }
101 
 /** hook for data locking (by its name, run after labels/features are locked)
  *
  * @param labels locked labels
  * @param features_ locked features
  */
103  virtual void post_lock(CLabels* labels, CFeatures* features_);
104 
 /** train on the locked data restricted to the given vector indices
  *
  * @param indices indices of the vectors to train on
  * @return whether training succeeded (presumably)
  */
106  virtual bool train_locked(SGVector<index_t> indices);
107 
 /** apply the machine to the locked data at the given indices
  *
  * @param indices indices of the vectors to apply to
  * @return binary labels for those vectors
  */
109  virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices);
110 
 /** compute the output for a single feature vector
  *
  * @param i index of the feature vector
  * @return output value for vector i
  */
112  virtual float64_t apply_one(int32_t i);
113 
114  protected:
115 
 /** compute raw outputs
  *
  * @param data features to apply to (defaults to NULL; presumably NULL
  *        means the already-attached features — confirm)
  * @return one output value per vector
  */
117  virtual SGVector<float64_t> apply_get_outputs(CFeatures* data=NULL);
118 
 /** train the machine
  *
  * @param data features to train on (defaults to NULL)
  * @return whether training succeeded
  */
120  virtual bool train_machine(CFeatures* data=NULL);
121 
 /** training routine used for the locked case
  *
  * @param tasks per-task index vectors (presumably a pointer to an array
  *        with one SGVector per task — confirm)
  * @return whether training succeeded
  */
123  virtual bool train_locked_implementation(SGVector<index_t>* tasks);
124 
 /** get per-task index vectors restricted to the current subset
  *
  * NOTE(review): this returns a raw pointer; who owns and frees it is not
  * visible in this header — confirm in the .cpp before using.
  * @return pointer to per-task index vectors
  */
126  SGVector<index_t>* get_subset_tasks_indices();
127 
128  private:
129 
 /** register member parameters (presumably with the parameter framework) */
131  void register_parameters();
132 
133  protected:
134 
 /** index of the currently selected task */
136  int32_t m_current_task;
137 
 /** task relation */
139  CTaskRelation* m_task_relation;
140 
 /** task weights (presumably one weight vector per task) */
142  SGMatrix<float64_t> m_tasks_w;
143 
 /** per-task values (presumably one bias per task) */
145  SGVector<float64_t> m_tasks_c;
146 
 /** one set of vector indices per entry (presumably one per task) */
148  std::vector< std::set<index_t> > m_tasks_indices;
149 
150 };
151 }
152 #endif //USE_GPL_SHOGUN
153 #endif
typedef double float64_t
Definition: common.h:50
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18

SHOGUN Machine Learning Toolbox - Documentation