SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RBM.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2014, Shogun Toolbox Foundation
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7 
8  * 1. Redistributions of source code must retain the above copyright notice,
9  * this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from this
17  * software without specific prior written permission.
18 
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  * POSSIBILITY OF SUCH DAMAGE.
30  *
31  * Written (W) 2014 Khaled Nasr
32  */
33 
34 #ifndef __RBM_H__
35 #define __RBM_H__
36 
37 #include <shogun/lib/config.h>
38 #ifdef HAVE_EIGEN3
39 
40 #include <shogun/lib/common.h>
41 #include <shogun/base/SGObject.h>
42 #include <shogun/lib/SGMatrix.h>
43 #include <shogun/lib/SGVector.h>
46 
47 namespace shogun
48 {
50 {
53 };
54 
56 {
60 };
61 
/** CRBM: presumably a restricted Boltzmann machine model, judging by the class
 * name, the "RBM" string returned by get_name(), and methods such as
 * contrastive_divergence() and free_energy().
 *
 * Derives from CSGObject and grants CDeepBeliefNetwork friend access to its
 * non-public members.
 *
 * NOTE(review): this text is an extract of a Doxygen source listing. Every line
 * is prefixed with its original header line number, and many original lines are
 * absent: declaration heads and tails, and most member variables. It is NOT
 * compilable as-is; TODO restore from the upstream RBM.h before editing code.
 */
123 class CRBM : public CSGObject
124 {
    // Lets CDeepBeliefNetwork reach the protected/private members below.
125 friend class CDeepBeliefNetwork;
126 
127 public:
    /** Default constructor. */
129  CRBM();
130 
    /** Constructs a model with num_hidden hidden units; visible units are
     * presumably added later via add_visible_group() — TODO confirm. */
136  CRBM(int32_t num_hidden);
137 
    /** Constructs a model with num_hidden hidden units and num_visible visible
     * units of type visible_unit_type (defaults to RBMVUT_BINARY). */
144  CRBM(int32_t num_hidden, int32_t num_visible,
145  ERBMVisibleUnitType visible_unit_type = RBMVUT_BINARY);
146 
    /** Virtual destructor (CRBM is a polymorphic CSGObject subclass). */
147  virtual ~CRBM();
148 
    /** Adds a group of num_units visible units of the given unit_type. */
154  virtual void add_visible_group(int32_t num_units, ERBMVisibleUnitType unit_type);
155 
    /** Initializes the model; sigma defaults to 0.01.
     * NOTE(review): presumably sigma is the spread of the random parameter
     * initialization — confirm against the implementation. */
162  virtual void initialize(float64_t sigma=0.01);
163 
    /** Sets the batch size (see m_batch_size). */
168  virtual void set_batch_size(int32_t batch_size);
169 
    /** Trains the model on the given dense float64 features. */
175  virtual void train(CDenseFeatures<float64_t>* features);
176 
    /** Sampling entry point; num_gibbs_steps and batch_size both default to 1.
     * NOTE(review): presumably runs num_gibbs_steps Gibbs steps on a chain of
     * batch_size samples — confirm. */
185  virtual void sample(int32_t num_gibbs_steps=1, int32_t batch_size=1);
186 
    /* NOTE(review): the first line of the following declaration (return type,
     * name and opening parenthesis, original line 198 or earlier) is absent
     * from this extract; only its parameter list survives. TODO restore. */
199  int32_t V,
200  int32_t num_gibbs_steps=1, int32_t batch_size=1);
201 
    /** Sampling variant conditioned on evidence, taking a visible-group index E
     * and a CDenseFeatures<float64_t>* evidence; num_gibbs_steps defaults to 1.
     * NOTE(review): semantics of E (which group is clamped) inferred from the
     * name only — confirm. */
211  virtual void sample_with_evidence(
212  int32_t E, CDenseFeatures<float64_t>* evidence,
213  int32_t num_gibbs_steps=1);
214 
    /* NOTE(review): the first line of the following declaration is absent from
     * this extract; only its parameter list (V, E, evidence, num_gibbs_steps)
     * survives. TODO restore. */
228  int32_t V,
229  int32_t E, CDenseFeatures<float64_t>* evidence,
230  int32_t num_gibbs_steps=1);
231 
    /** Resets the sampling chain. */
235  virtual void reset_chain();
236 
    /** Computes a float64_t free-energy value for the given visible matrix.
     * NOTE(review): the remainder of this parameter list (original line 256)
     * is absent from this extract, leaving the declaration unterminated. */
255  virtual float64_t free_energy(SGMatrix<float64_t> visible,
257 
    /** Writes free-energy gradients for `visible` into `gradients`.
     * positive_phase defaults to true; hidden_mean_given_visible defaults to an
     * empty SGMatrix<float64_t>. */
272  virtual void free_energy_gradients(SGMatrix<float64_t> visible,
273  SGVector<float64_t> gradients,
274  bool positive_phase = true,
275  SGMatrix<float64_t> hidden_mean_given_visible = SGMatrix<float64_t>());
276 
    /** Contrastive-divergence step on visible_batch, writing into `gradients`. */
283  virtual void contrastive_divergence(SGMatrix<float64_t> visible_batch,
284  SGVector<float64_t> gradients);
285 
    /* NOTE(review): original lines 286-344 are largely absent from this extract;
     * only blank separators and a bare brace pair (lines 313/315, whose function
     * head and body are missing) survive. TODO restore. */
295 
310 
313  {
315  }
316 
319 
327 
335 
343 
    /** Returns m_num_params. */
345  virtual int32_t get_num_parameters() { return m_num_params; }
346 
    /** Returns the object name "RBM". */
347  virtual const char* get_name() const { return "RBM"; }
348 
349 protected:
    /** Computes hidden-unit means from `visible` into `result`. */
351  virtual void mean_hidden(SGMatrix<float64_t> visible, SGMatrix<float64_t> result);
352 
    /** Computes visible-unit means from `hidden` into `result`. */
354  virtual void mean_visible(SGMatrix<float64_t> hidden, SGMatrix<float64_t> result);
355 
    /** Samples hidden states from `mean` into `result`. */
357  virtual void sample_hidden(SGMatrix<float64_t> mean, SGMatrix<float64_t> result);
358 
    /** Samples visible states from `mean` into `result`. */
360  virtual void sample_visible(SGMatrix<float64_t> mean, SGMatrix<float64_t> result);
361 
    /** sample_visible overload taking an int32_t index (presumably a visible
     * group index — confirm).
     * NOTE(review): the rest of this parameter list (original line 364) is
     * absent from this extract. */
363  virtual void sample_visible(int32_t index,
365 
366 private:
    /** Private initialization helper (presumably shared by the constructors). */
367  void init();
368 
369 public:
    /** Public setting; presumably the number of contrastive-divergence steps
     * used in training — confirm. */
373  int32_t cd_num_steps;
374 
    /* NOTE(review): further public member declarations (original lines 375-401)
     * are absent from this extract. */
378 
384 
387 
390 
395 
398 
    /** Public setting; presumably the maximum number of training epochs — confirm. */
402  int32_t max_num_epochs;
403 
    /* NOTE(review): further public member declarations (original lines 404-435)
     * are absent from this extract. */
409 
412 
419 
429 
432 
435 
436 protected:
    /** Number of hidden units. */
438  int32_t m_num_hidden;
439 
    /** Number of visible units. */
441  int32_t m_num_visible;
442 
    /** Batch size (set via set_batch_size()). */
444  int32_t m_batch_size;
445 
    /* NOTE(review): protected member declarations at original lines 446-457
     * are absent from this extract. */
448 
451 
454 
457 
    /** Parameter count, returned by get_num_parameters(). */
459  int32_t m_num_params;
460 
    /* NOTE(review): original lines 461-462 (likely a further member) are absent. */
463 };
464 
465 }
466 #endif
467 #endif

SHOGUN Machine Learning Toolbox - Documentation