SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
BeliefPropagation.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
11 #ifndef __BELIEF_PROPAGATION_H__
12 #define __BELIEF_PROPAGATION_H__
13 
14 #include <shogun/lib/SGVector.h>
17 
18 #include <vector>
19 #include <set>
20 
21 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11)
22  #include <unordered_map>
23 #else
24  #include <tr1/unordered_map>
25 #endif
26 
27 #ifndef DOXYGEN_SHOULD_SKIP_THIS
28 
29 namespace shogun
30 {
31 #define IGNORE_IN_CLASSLIST
32 
/** Kind of a node in the graph: a variable node or a factor node. */
enum ENodeType
{
    /** variable node (numeric value 0) */
    VAR_NODE = 0,
    /** factor node (numeric value 1) */
    FAC_NODE = 1
};
38 
/** Direction of a message along an edge between a variable and a factor. */
enum EEdgeType
{
    /** message sent from a variable to a factor (numeric value 0) */
    VAR_TO_FAC = 0,
    /** message sent from a factor to a variable (numeric value 1) */
    FAC_TO_VAR = 1
};
44 
45 struct GraphNode
46 {
47  int32_t node_id;
48  ENodeType node_type; // 1 var, 0 factor
49  int32_t parent; // where came from
50 
51  GraphNode(int32_t id, ENodeType type, int32_t pa)
52  : node_id(id), node_type(type), parent(pa) { }
53  ~GraphNode() { }
54 };
55 
56 struct MessageEdge
57 {
58  EEdgeType mtype; // 1 var_to_factor, 0 factor_to_var
59  int32_t child;
60  int32_t parent;
61 
62  MessageEdge(EEdgeType type, int32_t ch, int32_t pa)
63  : mtype(type), child(ch), parent(pa) { }
64 
65  ~MessageEdge() { }
66 
67  inline int32_t get_var_node()
68  {
69  return mtype == VAR_TO_FAC ? child : parent;
70  }
71 
72  inline int32_t get_factor_node()
73  {
74  return mtype == VAR_TO_FAC ? parent : child;
75  }
76 };
77 
79 IGNORE_IN_CLASSLIST class CBeliefPropagation : public CMAPInferImpl
80 {
81 public:
82  CBeliefPropagation();
83  CBeliefPropagation(CFactorGraph* fg);
84 
85  virtual ~CBeliefPropagation();
86 
88  virtual const char* get_name() const { return "BeliefPropagation"; }
89 
90  virtual float64_t inference(SGVector<int32_t> assignment);
91 
92 protected:
93  float64_t m_map_energy;
94 };
95 
104 IGNORE_IN_CLASSLIST class CTreeMaxProduct : public CBeliefPropagation
105 {
106 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11)
107  typedef std::unordered_map<uint32_t, uint32_t> msg_map_type;
108  typedef std::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type;
109  typedef std::unordered_multimap<int32_t, int32_t> var_factor_map_type;
110 #else
111  typedef std::tr1::unordered_map<uint32_t, uint32_t> msg_map_type;
112  typedef std::tr1::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type;
113  typedef std::tr1::unordered_multimap<int32_t, int32_t> var_factor_map_type;
114 #endif
115 
116 public:
117  CTreeMaxProduct();
118  CTreeMaxProduct(CFactorGraph* fg);
119 
120  virtual ~CTreeMaxProduct();
121 
123  virtual const char* get_name() const { return "TreeMaxProduct"; }
124 
125  virtual float64_t inference(SGVector<int32_t> assignment);
126 
127 protected:
128  void bottom_up_pass();
129  void top_down_pass();
130  void get_message_order(std::vector<MessageEdge*>& order, std::vector<bool>& is_root) const;
131 
132 private:
133  void init();
134 
135 private:
136  std::vector<MessageEdge*> m_msg_order;
137  std::vector<bool> m_is_root;
138  std::vector< std::vector<float64_t> > m_fw_msgs;
139  std::vector< std::vector<float64_t> > m_bw_msgs;
140  std::vector<int32_t> m_states;
141 
142  msg_map_type m_msg_map_var;
143  msg_map_type m_msg_map_fac;
144  msgset_map_type m_msgset_map_var;
145 };
146 
147 }
148 
149 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
150 
151 #endif

SHOGUN Machine Learning Toolbox - Documentation