SHOGUN 3.2.1
 Navigation: All Classes · Namespaces · Files · Functions · Variables · Typedefs · Enumerations · Enumerator · Friends · Macros · Groups · Pages
DeepBeliefNetwork.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2014, Shogun Toolbox Foundation
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7 
8  * 1. Redistributions of source code must retain the above copyright notice,
9  * this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from this
17  * software without specific prior written permission.
18 
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  * POSSIBILITY OF SUCH DAMAGE.
30  *
31  * Written (W) 2014 Khaled Nasr
32  */
33 
34 #ifndef __DEEPBELIEFNETWORK_H__
35 #define __DEEPBELIEFNETWORK_H__
36 
37 #include <shogun/lib/config.h>
38 #ifdef HAVE_EIGEN3
39 
40 #include <shogun/lib/common.h>
41 #include <shogun/base/SGObject.h>
42 #include <shogun/neuralnets/RBM.h>
43 #include <lib/SGMatrixList.h>
44 
// NOTE(review): this is a Doxygen source-listing rendering of
// DeepBeliefNetwork.h. The leading number on each line is the line number in
// the original header. Many original lines are absent here (e.g. the listing
// jumps from 55 to 92), including the class head at original line 91 and
// parts of several declarations, so this text is not compilable as-is.
45 namespace shogun
46 {
// Forward declarations of library types referenced by the declarations below.
// NOTE(review): SGMatrixList is also brought in by `#include <lib/SGMatrixList.h>`
// earlier in this file (which, unlike the other includes, has no `shogun/`
// prefix -- confirm that include path), so its forward declaration here is
// presumably redundant.
47 template <class T> class SGVector;
48 template <class T> class SGMatrix;
49 template <class T> class SGMatrixList;
50 template <class T> class CDenseFeatures;
51 template <class T> class CDynamicArray;
52 class CDynamicObjectArray;
53 class CNeuralNetwork;
54 class CNeuralLayer;
55 
// NOTE(review): the class head (original line 91) is missing from this
// listing; the class name CDeepBeliefNetwork is taken from the constructor,
// destructor and get_name() below. Its base class is not visible here.
92 {
93 public:
96 
	/** Constructor.
	 * @param num_visible_units presumably the size of the visible layer -- confirm
	 * @param unit_type visible unit type; defaults to RBMVUT_BINARY
	 */
102  CDeepBeliefNetwork(int32_t num_visible_units,
103  ERBMVisibleUnitType unit_type = RBMVUT_BINARY);
104 
	/** Virtual destructor. */
105  virtual ~CDeepBeliefNetwork();
106 
	/** Presumably appends a hidden layer of num_units units -- confirm. */
112  virtual void add_hidden_layer(int32_t num_units);
113 
	/** Initialization; sigma defaults to 0.01.
	 * NOTE(review): sigma is presumably the standard deviation used for
	 * random parameter initialization -- confirm against the .cpp.
	 */
119  virtual void initialize(float64_t sigma = 0.01);
120 
	/** Sets the batch size (see m_batch_size below). */
125  virtual void set_batch_size(int32_t batch_size);
126 
	/** Pre-training overload taking only the feature set. */
132  virtual void pre_train(CDenseFeatures<float64_t>* features);
133 
	/** Pre-training overload taking an index plus the feature set.
	 * NOTE(review): index presumably selects a single layer/RBM -- confirm.
	 */
140  virtual void pre_train(int32_t index, CDenseFeatures<float64_t>* features);
141 
	/** Training entry point taking the feature set. */
148  virtual void train(CDenseFeatures<float64_t>* features);
149 
	// NOTE(review): the head of the next declaration (original lines 150-162)
	// is missing from this listing; only its trailing parameters remain
	// (`features`, and `i` defaulting to -1).
163  CDenseFeatures<float64_t>* features, int32_t i=-1);
164 
	// NOTE(review): the head of the next declaration (original lines 165-175)
	// is missing from this listing; only its trailing parameters remain
	// (num_gibbs_steps defaulting to 1, batch_size defaulting to 1).
176  int32_t num_gibbs_steps=1, int32_t batch_size=1);
177 
	/** Resets the chain; presumably the persistent Gibbs chain used for
	 * sampling -- confirm.
	 */
179  virtual void reset_chain();
180 
	// NOTE(review): the head of the next declaration (original lines 181-193)
	// is missing from this listing; only its trailing parameters remain
	// (output_layer defaulting to NULL, sigma defaulting to 0.01).
194  CNeuralLayer* output_layer=NULL, float64_t sigma = 0.01);
195 
	// Returns a weight matrix selected by index.
	// NOTE(review): the rest of this declaration (original line 203) is missing
	// from this listing.
202  virtual SGMatrix<float64_t> get_weights(int32_t index,
204 
	// Returns a bias vector selected by index.
	// NOTE(review): the rest of this declaration (original line 212) is missing
	// from this listing.
211  virtual SGVector<float64_t> get_biases(int32_t index,
213 
	/** @return the object name, "DeepBeliefNetwork" */
214  virtual const char* get_name() const { return "DeepBeliefNetwork"; }
215 
216 protected:
	// Downward step for the layer at index, driven by params; sample_states
	// defaults to true.
	// NOTE(review): the middle parameter line (original line 219) is missing
	// from this listing.
218  virtual void down_step(int32_t index, SGVector<float64_t> params,
220  bool sample_states = true);
221 
	// Upward step for the layer at index, driven by params; sample_states
	// defaults to true.
	// NOTE(review): the middle parameter line (original line 224) is missing
	// from this listing.
223  virtual void up_step(int32_t index, SGVector<float64_t> params,
225  bool sample_states = true);
226 
	/** Wake-sleep step over a data matrix and the top RBM.
	 * Takes separate state lists (sleep/wake and their "p"-prefixed
	 * counterparts), separate generative/recognition parameter vectors, and
	 * matching gradient vectors. All arguments are passed by value.
	 * NOTE(review): whether the SG* arguments share storage with the caller's
	 * buffers (so gradients written here are visible to the caller) is not
	 * established by this header -- confirm in the implementation.
	 */
228  virtual void wake_sleep(SGMatrix<float64_t> data,
229  CRBM* top_rbm,
230  SGMatrixList<float64_t> sleep_states,
231  SGMatrixList<float64_t> wake_states,
232  SGMatrixList<float64_t> psleep_states,
233  SGMatrixList<float64_t> pwake_states,
234  SGVector<float64_t> gen_params,
235  SGVector<float64_t> rec_params,
236  SGVector<float64_t> gen_gradients,
237  SGVector<float64_t> rec_gradients);
238 
239 private:
	/** Private initialization helper. */
240  void init();
241 
242 public:
	// NOTE(review): most public member declarations in this section (original
	// lines 243-343) are missing from this listing; only cd_num_steps and
	// max_num_epochs survive.
247 
252 
257 
262 
267 
272 
277 
282 
287 
292 
297 
302 
	/** Presumably the number of contrastive-divergence steps -- confirm. */
306  int32_t cd_num_steps;
307 
312 
	/** Presumably an upper bound on training epochs -- confirm. */
316  int32_t max_num_epochs;
317 
323 
326 
333 
343 
344 protected:
	// NOTE(review): some protected member declarations in this section are
	// missing from this listing.
347 
	/** Layer count. */
349  int32_t m_num_layers;
350 
353 
356 
	/** Batch size (set via set_batch_size). */
358  int32_t m_batch_size;
359 
362 
	/** Parameter count. */
364  int32_t m_num_params;
365 
368 
373 
377 };
378 
379 }
380 #endif
381 #endif

SHOGUN Machine Learning Toolbox - Documentation