SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
KMeansMiniBatch.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2014 Parijat Mazumdar
8  * Written (W) 2016 Saurabh Mahindre
9  */
10 
11 #ifndef _MBKMEANS_H__
12 #define _MBKMEANS_H__
13 
14 #include <shogun/lib/config.h>
15 
16 #include <shogun/lib/common.h>
17 #include <shogun/io/SGIO.h>
21 
22 namespace shogun
23 {
24 class CKMeansBase;
25 
28 {
29  public:
32 
39  CKMeansMiniBatch(int32_t k, CDistance* d, bool kmeanspp=false);
40 
46  CKMeansMiniBatch(int32_t k_i, CDistance* d_i, SGMatrix<float64_t> centers_i);
47 
48  virtual ~CKMeansMiniBatch();
49 
51  virtual const char* get_name() const { return "KMeansMiniBatch"; } // Returns the fixed name string "KMeansMiniBatch" for this class.
52 
57  void set_batch_size(int32_t b);
58 
63  int32_t get_batch_size() const;
64 
70  void set_mb_iter(int32_t t);
71 
76  int32_t get_mb_iter() const;
77 
83  void set_mb_params(int32_t b, int32_t t);
84 
85  protected:
86 
95  virtual bool train_machine(CFeatures* data=NULL);
96 
99  void minibatch_KMeans();
100 
101  private:
102 
103  void init_mb_params();
104 
105  /* Choose b integers between 0 and num-1 and return them as an SGVector<int32_t>.
106  * NOTE(review): the name suggests the integers are drawn at random; whether they are distinct is not visible in this header - confirm against the implementation.
107  */
108  SGVector<int32_t> mbchoose_rand(int32_t b, int32_t num);
109 
110  protected:
111 
113  int32_t batch_size;
114 
116  int32_t minib_iter;
117 
118 };
119 }
120 #endif
int32_t get_batch_size() const
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:87
void set_batch_size(int32_t b)
virtual bool train_machine(CFeatures *data=NULL)
void set_mb_params(int32_t b, int32_t t)
int32_t get_mb_iter() const
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
SGMatrix< float64_t > kmeanspp()
Definition: KMeansBase.cpp:276
virtual const char * get_name() const

SHOGUN Machine Learning Toolbox - Documentation