SHOGUN  6.1.3
Machine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Written (W) 2011-2012 Heiko Strathmann
9  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
10  */
11 
12 #include <rxcpp/rx-lite.hpp>
13 #include <shogun/base/init.h>
14 #include <shogun/lib/Signal.h>
15 #include <shogun/machine/Machine.h>
16 
17 using namespace shogun;
18 
20  : CSGObject(), m_max_train_time(0), m_labels(NULL), m_solver_type(ST_AUTO),
21  m_cancel_computation(false), m_pause_computation_flag(false)
22 {
23  m_data_locked=false;
25 
26  SG_ADD(&m_max_train_time, "max_train_time",
27  "Maximum training time.", MS_NOT_AVAILABLE);
28  SG_ADD((machine_int_t*) &m_solver_type, "solver_type",
29  "Type of solver.", MS_NOT_AVAILABLE);
30  SG_ADD((CSGObject**) &m_labels, "labels",
31  "Labels to be used.", MS_NOT_AVAILABLE);
32  SG_ADD(&m_store_model_features, "store_model_features",
33  "Should feature data of model be stored after training?", MS_NOT_AVAILABLE);
34  SG_ADD(&m_data_locked, "data_locked",
35  "Indicates whether data is locked", MS_NOT_AVAILABLE);
36 }
37 
39 {
41 }
42 
44 {
45  /* not allowed to train on locked data */
46  if (m_data_locked)
47  {
48  SG_ERROR("%s::train data_lock() was called, only train_locked() is"
49  " possible. Call data_unlock if you want to call train()\n",
50  get_name());
51  }
52 
54  {
55  if (m_labels == NULL)
56  SG_ERROR("%s@%p: No labels given", get_name(), this)
57 
59  }
60 
61  auto sub = connect_to_signal_handler();
62  bool result = train_machine(data);
63  sub.unsubscribe();
65 
68 
69  return result;
70 }
71 
73 {
74  if (lab != NULL)
75  if (!is_label_valid(lab))
76  SG_ERROR("Invalid label for %s", get_name())
77 
78  SG_REF(lab);
80  m_labels = lab;
81 }
82 
84 {
86  return m_labels;
87 }
88 
90 {
91  m_max_train_time = t;
92 }
93 
95 {
96  return m_max_train_time;
97 }
98 
100 {
101  return CT_NONE;
102 }
103 
105 {
106  m_solver_type = st;
107 }
108 
110 {
111  return m_solver_type;
112 }
113 
115 {
116  m_store_model_features = store_model;
117 }
118 
119 void CMachine::data_lock(CLabels* labs, CFeatures* features)
120 {
121  SG_DEBUG("entering %s::data_lock\n", get_name())
122  if (!supports_locking())
123  {
124  {
125  SG_ERROR("%s::data_lock(): Machine does not support data locking!\n",
126  get_name());
127  }
128  }
129 
130  if (!labs)
131  {
132  SG_ERROR("%s::data_lock() is not possible will NULL labels!\n",
133  get_name());
134  }
135 
136  /* first set labels */
137  set_labels(labs);
138 
139  if (m_data_locked)
140  {
141  SG_ERROR("%s::data_lock() was already called. Dont lock twice!",
142  get_name());
143  }
144 
145  m_data_locked=true;
146  post_lock(labs,features);
147  SG_DEBUG("leaving %s::data_lock\n", get_name())
148 }
149 
151 {
152  SG_DEBUG("entering %s::data_lock\n", get_name())
153  if (m_data_locked)
154  m_data_locked=false;
155 
156  SG_DEBUG("leaving %s::data_lock\n", get_name())
157 }
158 
160 {
161  SG_DEBUG("entering %s::apply(%s at %p)\n",
162  get_name(), data ? data->get_name() : "NULL", data);
163 
164  CLabels* result=NULL;
165 
166  switch (get_machine_problem_type())
167  {
168  case PT_BINARY:
169  result=apply_binary(data);
170  break;
171  case PT_REGRESSION:
172  result=apply_regression(data);
173  break;
174  case PT_MULTICLASS:
175  result=apply_multiclass(data);
176  break;
177  case PT_STRUCTURED:
178  result=apply_structured(data);
179  break;
180  case PT_LATENT:
181  result=apply_latent(data);
182  break;
183  default:
184  SG_ERROR("Unknown problem type")
185  break;
186  }
187 
188  SG_DEBUG("leaving %s::apply(%s at %p)\n",
189  get_name(), data ? data->get_name() : "NULL", data);
190 
191  return result;
192 }
193 
195 {
196  switch (get_machine_problem_type())
197  {
198  case PT_BINARY:
199  return apply_locked_binary(indices);
200  case PT_REGRESSION:
201  return apply_locked_regression(indices);
202  case PT_MULTICLASS:
203  return apply_locked_multiclass(indices);
204  case PT_STRUCTURED:
205  return apply_locked_structured(indices);
206  case PT_LATENT:
207  return apply_locked_latent(indices);
208  default:
209  SG_ERROR("Unknown problem type")
210  break;
211  }
212  return NULL;
213 }
214 
216 {
217  SG_ERROR("This machine does not support apply_binary()\n")
218  return NULL;
219 }
220 
222 {
223  SG_ERROR("This machine does not support apply_regression()\n")
224  return NULL;
225 }
226 
228 {
229  SG_ERROR("This machine does not support apply_multiclass()\n")
230  return NULL;
231 }
232 
234 {
235  SG_ERROR("This machine does not support apply_structured()\n")
236  return NULL;
237 }
238 
240 {
241  SG_ERROR("This machine does not support apply_latent()\n")
242  return NULL;
243 }
244 
246 {
247  SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented "
248  "for %s\n", get_name());
249  return NULL;
250 }
251 
253 {
254  SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented "
255  "for %s\n", get_name());
256  return NULL;
257 }
258 
260 {
261  SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented "
262  "for %s\n", get_name());
263  return NULL;
264 }
265 
267 {
268  SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented "
269  "for %s\n", get_name());
270  return NULL;
271 }
272 
274 {
275  SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented "
276  "for %s\n", get_name());
277  return NULL;
278 }
279 
281 {
282  // Subscribe this algorithm to the signal handler
283  auto subscriber = rxcpp::make_subscriber<int>(
284  [this](int i) {
285  if (i == SG_PAUSE_COMP)
286  this->on_pause();
287  else
288  this->on_next();
289  },
290  [this]() { this->on_complete(); });
291  return get_global_signal()->get_observable()->subscribe(subscriber);
292 }
virtual const char * get_name() const =0
EMachineType
Definition: Machine.h:36
void set_max_train_time(float64_t t)
Definition: Machine.cpp:89
Base class of the labels used in Structured Output (SO) problems.
Real Labels are real-valued labels.
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:194
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
ESolverType
Definition: Machine.h:101
float64_t m_max_train_time
Definition: Machine.h:433
CLabels * m_labels
Definition: Machine.h:436
void reset_computation_variables()
Definition: Machine.h:403
#define SG_ERROR(...)
Definition: SGIO.h:128
ESolverType m_solver_type
Definition: Machine.h:439
bool m_data_locked
Definition: Machine.h:445
virtual CStructuredLabels * apply_locked_structured(SGVector< index_t > indices)
Definition: Machine.cpp:266
virtual bool train_machine(CFeatures *data=NULL)
Definition: Machine.h:361
virtual void on_complete()
Definition: Machine.h:427
bool m_store_model_features
Definition: Machine.h:442
virtual const char * get_name() const
Definition: Machine.h:348
#define SG_REF(x)
Definition: SGObject.h:52
SGObservableS * get_observable()
Definition: Signal.h:61
Multiclass Labels for multi-class classification.
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
Definition: Machine.cpp:215
virtual void on_next()
Definition: Machine.h:411
CSignal * get_global_signal()
Definition: init.cpp:202
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:114
virtual ~CMachine()
Definition: Machine.cpp:38
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:124
rxcpp::subscription connect_to_signal_handler()
Definition: Machine.cpp:280
double float64_t
Definition: common.h:60
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: Machine.cpp:221
virtual void data_unlock()
Definition: Machine.cpp:150
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:119
virtual CLabels * get_labels()
Definition: Machine.cpp:83
float64_t get_max_train_time()
Definition: Machine.cpp:94
ESolverType get_solver_type()
Definition: Machine.cpp:109
virtual CLatentLabels * apply_latent(CFeatures *data=NULL)
Definition: Machine.cpp:239
virtual EMachineType get_classifier_type()
Definition: Machine.cpp:99
virtual EProblemType get_machine_problem_type() const
Definition: Machine.h:311
virtual CRegressionLabels * apply_locked_regression(SGVector< index_t > indices)
Definition: Machine.cpp:252
virtual void store_model_features()
Definition: Machine.h:378
virtual bool supports_locking() const
Definition: Machine.h:305
virtual CMulticlassLabels * apply_locked_multiclass(SGVector< index_t > indices)
Definition: Machine.cpp:259
#define SG_UNREF(x)
Definition: SGObject.h:53
virtual CStructuredLabels * apply_structured(CFeatures *data=NULL)
Definition: Machine.cpp:233
#define SG_DEBUG(...)
Definition: SGIO.h:106
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
virtual void post_lock(CLabels *labs, CFeatures *features)
Definition: Machine.h:299
int machine_int_t
Definition: common.h:69
virtual bool is_label_valid(CLabels *lab) const
Definition: Machine.h:391
The class Features is the base class of all feature objects.
Definition: Features.h:69
virtual CBinaryLabels * apply_locked_binary(SGVector< index_t > indices)
Definition: Machine.cpp:245
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:43
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: Machine.cpp:227
virtual bool train_require_labels() const
Definition: Machine.h:397
#define SG_ADD(...)
Definition: SGObject.h:93
virtual CLatentLabels * apply_locked_latent(SGVector< index_t > indices)
Definition: Machine.cpp:273
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:72
Abstract class for latent labels. As latent labels always depend on the given application, this class only defines the API that the user has to implement for latent labels.
Definition: LatentLabels.h:26
virtual void ensure_valid(const char *context=NULL)=0
void set_solver_type(ESolverType st)
Definition: Machine.cpp:104
virtual void on_pause()
Definition: Machine.h:418
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:159

SHOGUN Machine Learning Toolbox - Documentation