SHOGUN  6.1.3
DataManager.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 - 2017 Soumyajit De
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 #include <memory>
32 #include <shogun/io/SGIO.h>
39 
40 using namespace shogun;
41 using namespace internal;
42 
44 {
45  SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions);
46  fetchers.resize(num_distributions);
47  std::fill(fetchers.begin(), fetchers.end(), nullptr);
48 
49  train_test_mode=default_train_test_mode;
50  train_mode=default_train_mode;
51  train_test_ratio=default_train_test_ratio;
52  cross_validation_mode=default_cross_validation_mode;
53 }
54 
56 {
57 }
58 
60 {
61  SG_SDEBUG("Entering!\n");
62  index_t n=0;
63  typedef const std::unique_ptr<DataFetcher> fetcher_type;
64  if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { return f->m_num_samples==0; }))
65  SG_SERROR("number of samples from all the distributions are not set!")
66  else
67  std::for_each(fetchers.begin(), fetchers.end(), [&n](fetcher_type& f) { n+=f->m_num_samples; });
68  SG_SDEBUG("Leaving!\n");
69  return n;
70 }
71 
73 {
74  SG_SDEBUG("Entering!\n");
75  index_t min_blocksize=0;
76  typedef const std::unique_ptr<DataFetcher> fetcher_type;
77  if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { return f->m_num_samples==0; }))
78  SG_SERROR("number of samples from all the distributions are not set!")
79  else
80  {
81  index_t divisor=0;
82  for (size_t i=0; i<fetchers.size(); ++i)
83  divisor=CMath::gcd(divisor, fetchers[i]->m_num_samples);
84  min_blocksize=get_num_samples()/divisor;
85  }
86  SG_SDEBUG("min blocksize is %d!", min_blocksize);
87  SG_SDEBUG("Leaving!\n");
88  return min_blocksize;
89 }
90 
92 {
93  SG_SDEBUG("Entering!\n");
94  auto n=get_num_samples();
95 
96  REQUIRE(n>0,
97  "Total number of samples is 0! Please set the number of samples!\n");
98  REQUIRE(blocksize>0 && blocksize<=n,
99  "The blocksize has to be within [0, %d], given = %d!\n",
100  n, blocksize);
101  REQUIRE(n%blocksize==0,
102  "Total number of samples (%d) has to be divisble by the blocksize (%d)!\n",
103  n, blocksize);
104 
105  for (size_t i=0; i<fetchers.size(); ++i)
106  {
107  index_t m=fetchers[i]->m_num_samples;
108  REQUIRE((blocksize*m)%n==0,
109  "Blocksize (%d) cannot be even distributed with a ratio of %f!\n",
110  blocksize, m/n);
111  fetchers[i]->fetch_blockwise().with_blocksize(blocksize*m/n);
112  SG_SDEBUG("block[%d].size = ", i, blocksize*m/n);
113  }
114  SG_SDEBUG("Leaving!\n");
115 }
116 
118 {
119  SG_SDEBUG("Entering!\n");
120  REQUIRE(num_blocks_per_burst>0,
121  "Number of blocks per burst (%d) has to be greater than 0!\n",
122  num_blocks_per_burst);
123 
124  index_t blocksize=0;
125  typedef std::unique_ptr<DataFetcher> fetcher_type;
126  std::for_each(fetchers.begin(), fetchers.end(), [&blocksize](fetcher_type& f)
127  {
128  blocksize+=f->m_block_details.m_blocksize;
129  });
130  REQUIRE(blocksize>0,
131  "Blocksizes are not set!\n");
132 
133  index_t max_num_blocks_per_burst=get_num_samples()/blocksize;
134  if (num_blocks_per_burst>max_num_blocks_per_burst)
135  {
136  SG_SINFO("There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
137  SG_SINFO("Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
138  num_blocks_per_burst=max_num_blocks_per_burst;
139  }
140 
141  for (size_t i=0; i<fetchers.size(); ++i)
142  fetchers[i]->fetch_blockwise().with_num_blocks_per_burst(num_blocks_per_burst);
143  SG_SDEBUG("Leaving!\n");
144 }
145 
147 {
148  SG_SDEBUG("Entering!\n");
149  REQUIRE(i<(int64_t)fetchers.size(),
150  "Value of i (%d) should be between 0 and %d, inclusive!",
151  i, fetchers.size()-1);
152  SG_SDEBUG("Leaving!\n");
153  return InitPerFeature(fetchers[i]);
154 }
155 
157 {
158  SG_SDEBUG("Entering!\n");
159  REQUIRE(i<(int64_t)fetchers.size(),
160  "Value of i (%d) should be between 0 and %d, inclusive!",
161  i, fetchers.size()-1);
162  SG_SDEBUG("Leaving!\n");
163  if (fetchers[i]!=nullptr)
164  return fetchers[i]->m_samples;
165  else
166  return nullptr;
167 }
168 
170 {
171  SG_SDEBUG("Entering!\n");
172  REQUIRE(i<(int64_t)fetchers.size(),
173  "Value of i (%d) should be between 0 and %d, inclusive!",
174  i, fetchers.size()-1);
175  SG_SDEBUG("Leaving!\n");
176  return fetchers[i]->m_num_samples;
177 }
178 
180 {
181  SG_SDEBUG("Entering!\n");
182  REQUIRE(i<(int64_t)fetchers.size(),
183  "Value of i (%d) should be between 0 and %d, inclusive!",
184  i, fetchers.size()-1);
185  SG_SDEBUG("Leaving!\n");
186  if (fetchers[i]!=nullptr)
187  return fetchers[i]->get_num_samples();
188  else
189  return 0;
190 }
191 
193 {
194  SG_SDEBUG("Entering!\n");
195  REQUIRE(i<(int64_t)fetchers.size(),
196  "Value of i (%d) should be between 0 and %d, inclusive!",
197  i, fetchers.size()-1);
198  SG_SDEBUG("Leaving!\n");
199  if (fetchers[i]!=nullptr)
200  return fetchers[i]->m_block_details.m_blocksize;
201  else
202  return 0;
203 }
204 
205 void DataManager::set_blockwise(bool blockwise)
206 {
207  SG_SDEBUG("Entering!\n");
208  for (size_t i=0; i<fetchers.size(); ++i)
209  fetchers[i]->set_blockwise(blockwise);
210  SG_SDEBUG("Leaving!\n");
211 }
212 
213 const bool DataManager::is_blockwise() const
214 {
215  SG_SDEBUG("Entering!\n");
216  bool blockwise=true;
217  for (size_t i=0; i<fetchers.size(); ++i)
218  blockwise&=!fetchers[i]->m_block_details.m_full_data;
219  SG_SDEBUG("Leaving!\n");
220  return blockwise;
221 }
222 
223 void DataManager::set_train_test_mode(bool on)
224 {
225  train_test_mode=on;
226  if (!train_test_mode)
227  {
228  train_mode=default_train_mode;
229  train_test_ratio=default_train_test_ratio;
230  cross_validation_mode=default_cross_validation_mode;
231  }
232  REQUIRE(fetchers.size()>0, "Features are not set!");
233  typedef std::unique_ptr<DataFetcher> fetcher_type;
234  std::for_each(fetchers.begin(), fetchers.end(), [this, on](fetcher_type& f)
235  {
236  f->set_train_test_mode(on);
237  if (on)
238  {
239  f->set_train_mode(train_mode);
240  f->set_train_test_ratio(train_test_ratio);
241  }
242  });
243 }
244 
245 bool DataManager::is_train_test_mode() const
246 {
247  return train_test_mode;
248 }
249 
250 void DataManager::set_train_mode(bool on)
251 {
252  if (train_test_mode)
253  train_mode=on;
254  else
255  {
256  SG_SERROR("Train mode cannot be used without turning on Train/Test mode first!"
257  "Please call set_train_test_mode(True) before using this method!\n");
258  }
259 }
260 
261 bool DataManager::is_train_mode() const
262 {
263  return train_mode;
264 }
265 
266 void DataManager::set_cross_validation_mode(bool on)
267 {
268  if (train_test_mode)
269  cross_validation_mode=on;
270  else
271  {
272  SG_SERROR("Cross-validation mode cannot be used without turning on Train/Test mode first!"
273  "Please call set_train_test_mode(True) before using this method!\n");
274  }
275 }
276 
277 bool DataManager::is_cross_validation_mode() const
278 {
279  return cross_validation_mode;
280 }
281 
282 void DataManager::set_train_test_ratio(float64_t ratio)
283 {
284  if (train_test_mode)
285  train_test_ratio=ratio;
286  else
287  {
288  SG_SERROR("Train-test ratio cannot be set without turning on Train/Test mode first!"
289  "Please call set_train_test_mode(True) before using this method!\n");
290  }
291 }
292 
293 float64_t DataManager::get_train_test_ratio() const
294 {
295  return train_test_ratio;
296 }
297 
298 index_t DataManager::get_num_folds() const
299 {
300  return ceil(get_train_test_ratio())+1;
301 }
302 
303 void DataManager::shuffle_features()
304 {
305  SG_SDEBUG("Entering!\n");
306  REQUIRE(fetchers.size()>0, "Features are not set!");
307  typedef std::unique_ptr<DataFetcher> fetcher_type;
308  std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->shuffle_features(); });
309  SG_SDEBUG("Leaving!\n");
310 }
311 
312 void DataManager::unshuffle_features()
313 {
314  SG_SDEBUG("Entering!\n");
315  REQUIRE(fetchers.size()>0, "Features are not set!");
316  typedef std::unique_ptr<DataFetcher> fetcher_type;
317  std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->unshuffle_features(); });
318  SG_SDEBUG("Leaving!\n");
319 }
320 
321 void DataManager::init_active_subset()
322 {
323  SG_SDEBUG("Entering!\n");
324 
325  REQUIRE(train_test_mode,
326  "Train-test subset cannot be used without turning on Train/Test mode first!"
327  "Please call set_train_test_mode(True) before using this method!\n");
328  REQUIRE(fetchers.size()>0, "Features are not set!");
329 
330  typedef std::unique_ptr<DataFetcher> fetcher_type;
331  std::for_each(fetchers.begin(), fetchers.end(), [this](fetcher_type& f)
332  {
333  f->set_train_mode(train_mode);
334  f->set_train_test_ratio(train_test_ratio);
335  f->init_active_subset();
336  });
337  SG_SDEBUG("Leaving!\n");
338 }
339 
340 void DataManager::use_fold(index_t idx)
341 {
342  SG_SDEBUG("Entering!\n");
343 
344  REQUIRE(train_test_mode,
345  "Fold subset cannot be used without turning on Train/Test mode first!"
346  "Please call set_train_test_mode(True) before using this method!\n");
347  REQUIRE(fetchers.size()>0, "Features are not set!");
348  REQUIRE(idx>=0, "Fold index has to be in [0, %d]!", get_num_folds()-1);
349  REQUIRE(idx<get_num_folds(), "Fold index has to be in [0, %d]!", get_num_folds()-1);
350 
351  typedef std::unique_ptr<DataFetcher> fetcher_type;
352  std::for_each(fetchers.begin(), fetchers.end(), [this, idx](fetcher_type& f)
353  {
354  f->set_train_mode(train_mode);
355  f->set_train_test_ratio(train_test_ratio);
356  f->use_fold(idx);
357  });
358  SG_SDEBUG("Leaving!\n");
359 }
360 
361 void DataManager::start()
362 {
363  SG_SDEBUG("Entering!\n");
364  REQUIRE(fetchers.size()>0, "Features are not set!");
365 
366  if (train_test_mode && !cross_validation_mode)
367  init_active_subset();
368 
369  typedef std::unique_ptr<DataFetcher> fetcher_type;
370  std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->start(); });
371  SG_SDEBUG("Leaving!\n");
372 }
373 
374 NextSamples DataManager::next()
375 {
376  SG_SDEBUG("Entering!\n");
377 
378  // sets the number of feature objects (number of distributions)
379  NextSamples next_samples(fetchers.size());
380 
381  // fetch a number of blocks (per burst) from each distribution
382  for (size_t i=0; i<fetchers.size(); ++i)
383  {
384  auto feats=fetchers[i]->next();
385  if (feats!=nullptr)
386  {
387  auto blocksize=fetchers[i]->m_block_details.m_blocksize;
388  auto num_blocks_curr_burst=feats->get_num_vectors()/blocksize;
389 
390  // use same number of blocks from all the distributions
391  if (next_samples.m_num_blocks==0)
392  next_samples.m_num_blocks=num_blocks_curr_burst;
393  else
394  ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst);
395 
396  next_samples[i]=Block::create_blocks(feats, num_blocks_curr_burst, blocksize);
397  SG_UNREF(feats);
398  }
399  }
400  SG_SDEBUG("Leaving!\n");
401  return next_samples;
402 }
403 
404 void DataManager::end()
405 {
406  SG_SDEBUG("Entering!\n");
407  REQUIRE(fetchers.size()>0, "Features are not set!");
408  typedef std::unique_ptr<DataFetcher> fetcher_type;
409  std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->end(); });
410  SG_SDEBUG("Leaving!\n");
411 }
412 
413 void DataManager::reset()
414 {
415  SG_SDEBUG("Entering!\n");
416  REQUIRE(fetchers.size()>0, "Features are not set!");
417  typedef std::unique_ptr<DataFetcher> fetcher_type;
418  std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); });
419  SG_SDEBUG("Leaving!\n");
420 }
static int32_t gcd(int32_t a, int32_t b)
Definition: Math.h:937
static std::vector< Block > create_blocks(CFeatures *feats, index_t num_blocks, index_t size)
Definition: Block.cpp:79
index_t get_num_samples() const
Definition: DataManager.cpp:59
DataManager(index_t num_distributions)
Definition: DataManager.cpp:43
index_t & num_samples_at(index_t i)
int32_t index_t
Definition: common.h:72
void set_blocksize(index_t blocksize)
Definition: DataManager.cpp:91
InitPerFeature samples_at(index_t i)
#define REQUIRE(x,...)
Definition: SGIO.h:181
#define ASSERT(x)
Definition: SGIO.h:176
double float64_t
Definition: common.h:60
void set_num_blocks_per_burst(index_t num_blocks_per_burst)
#define SG_UNREF(x)
Definition: SGObject.h:53
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
#define SG_SDEBUG(...)
Definition: SGIO.h:153
The class Features is the base class of all feature objects.
Definition: Features.h:69
#define SG_SERROR(...)
Definition: SGIO.h:164
#define SG_SINFO(...)
Definition: SGIO.h:158
index_t get_min_blocksize() const
Definition: DataManager.cpp:72
class NextSamples is the return type for next() call in DataManager. If there are no more samples (fr...
Definition: NextSamples.h:68
const index_t blocksize_at(index_t i) const

SHOGUN Machine Learning Toolbox - Documentation