SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
BeliefPropagation.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
13 #include <shogun/io/SGIO.h>
14 #include <numeric>
15 #include <algorithm>
16 #include <functional>
17 #include <stack>
18 
19 using namespace shogun;
20 
21 CBeliefPropagation::CBeliefPropagation()
22  : CMAPInferImpl()
23 {
24  SG_UNSTABLE("CBeliefPropagation::CBeliefPropagation()", "\n");
25 }
26 
27 CBeliefPropagation::CBeliefPropagation(CFactorGraph* fg)
28  : CMAPInferImpl(fg)
29 {
30 }
31 
32 CBeliefPropagation::~CBeliefPropagation()
33 {
34 }
35 
36 float64_t CBeliefPropagation::inference(SGVector<int32_t> assignment)
37 {
38  SG_ERROR("%s::inference(): please use TreeMaxProduct or LoopyMaxProduct!\n", get_name());
39  return 0;
40 }
41 
42 // -----------------------------------------------------------------
43 
44 CTreeMaxProduct::CTreeMaxProduct()
45  : CBeliefPropagation()
46 {
47  SG_UNSTABLE("CTreeMaxProduct::CTreeMaxProduct()", "\n");
48 
49  init();
50 }
51 
52 CTreeMaxProduct::CTreeMaxProduct(CFactorGraph* fg)
53  : CBeliefPropagation(fg)
54 {
55  ASSERT(m_fg != NULL);
56 
57  init();
58 
59  CDisjointSet* dset = m_fg->get_disjoint_set();
60  bool is_connected = dset->get_connected();
61  SG_UNREF(dset);
62 
63  if (!is_connected)
64  m_fg->connect_components();
65 
66  get_message_order(m_msg_order, m_is_root);
67 
68  // calculate lookup tables for forward messages
69  // a key is unique because a tree has only one root
70  // a var or a factor has only one edge towards root
71  for (uint32_t mi = 0; mi < m_msg_order.size(); mi++)
72  {
73  if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var_to_factor
74  {
75  // <var_id, msg_id>
76  m_msg_map_var[m_msg_order[mi]->child] = mi;
77  }
78  else // factor_to_var
79  {
80  // <fac_id, msg_id>
81  m_msg_map_fac[m_msg_order[mi]->child] = mi;
82  // collect incoming msgs for each var_id
83  m_msgset_map_var[m_msg_order[mi]->parent].insert(mi);
84  }
85  }
86 
87 }
88 
89 CTreeMaxProduct::~CTreeMaxProduct()
90 {
91  if (!m_msg_order.empty())
92  {
93  for (std::vector<MessageEdge*>::iterator it = m_msg_order.begin(); it != m_msg_order.end(); ++it)
94  delete *it;
95  }
96 }
97 
98 void CTreeMaxProduct::init()
99 {
100  m_msg_order = std::vector<MessageEdge*>(m_fg->get_num_edges(), (MessageEdge*) NULL);
101  m_is_root = std::vector<bool>(m_fg->get_cardinalities().size(), false);
102  m_fw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
103  std::vector<float64_t>());
104  m_bw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
105  std::vector<float64_t>());
106  m_states = std::vector<int32_t>(m_fg->get_cardinalities().size(), 0);
107 
108  m_msg_map_var = msg_map_type();
109  m_msg_map_fac = msg_map_type();
110  m_msgset_map_var = msgset_map_type();
111 }
112 
113 void CTreeMaxProduct::get_message_order(std::vector<MessageEdge*>& order,
114  std::vector<bool>& is_root) const
115 {
116  ASSERT(m_fg->is_acyclic_graph());
117 
118  // 1) pick up roots according to union process of disjoint sets
119  CDisjointSet* dset = m_fg->get_disjoint_set();
120  if (!dset->get_connected())
121  {
122  SG_UNREF(dset);
123  SG_ERROR("%s::get_root_indicators(): run connect_components() first!\n", get_name());
124  }
125 
126  int32_t num_vars = m_fg->get_cardinalities().size();
127  if (is_root.size() != (uint32_t)num_vars)
128  is_root.resize(num_vars);
129 
130  std::fill(is_root.begin(), is_root.end(), false);
131 
132  for (int32_t vi = 0; vi < num_vars; vi++)
133  is_root[dset->find_set(vi)] = true;
134 
135  SG_UNREF(dset);
136  ASSERT(std::accumulate(is_root.begin(), is_root.end(), 0) >= 1);
137 
138  // 2) caculate message order
139  // <var_id, fac_id>
140  var_factor_map_type vf_map;
141  CDynamicObjectArray* facs = m_fg->get_factors();
142 
143  for (int32_t fi = 0; fi < facs->get_num_elements(); ++fi)
144  {
145  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fi));
146  SGVector<int32_t> vars = fac->get_variables();
147  for (int32_t vi = 0; vi < vars.size(); vi++)
148  vf_map.insert(var_factor_map_type::value_type(vars[vi], fi));
149 
150  SG_UNREF(fac);
151  }
152 
153  std::stack<GraphNode*> node_stack;
154  // init node_stack with root nodes
155  for (uint32_t ni = 0; ni < is_root.size(); ni++)
156  {
157  if (is_root[ni])
158  {
159  // node_id = ni, node_type = variable, parent = none
160  node_stack.push(new GraphNode(ni, VAR_NODE, -1));
161  }
162  }
163 
164  if (order.size() != (uint32_t)(m_fg->get_num_edges()))
165  order.resize(m_fg->get_num_edges());
166 
167  // find reverse order
168  int32_t eid = m_fg->get_num_edges() - 1;
169  while (!node_stack.empty())
170  {
171  GraphNode* node = node_stack.top();
172  node_stack.pop();
173 
174  if (node->node_type == VAR_NODE) // child: factor -> parent: var
175  {
176  typedef var_factor_map_type::const_iterator const_iter;
177  std::pair<const_iter, const_iter> adj_factors = vf_map.equal_range(node->node_id);
178  for (const_iter mi = adj_factors.first; mi != adj_factors.second; ++mi)
179  {
180  int32_t adj_factor_id = mi->second;
181  if (adj_factor_id == node->parent)
182  continue;
183 
184  order[eid--] = new MessageEdge(FAC_TO_VAR, adj_factor_id, node->node_id);
185  // add factor node to node_stack
186  node_stack.push(new GraphNode(adj_factor_id, FAC_NODE, node->node_id));
187  }
188  }
189  else // child: var -> parent: factor
190  {
191  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(node->node_id));
192  SGVector<int32_t> vars = fac->get_variables();
193  SG_UNREF(fac);
194 
195  for (int32_t vi = 0; vi < vars.size(); vi++)
196  {
197  if (vars[vi] == node->parent)
198  continue;
199 
200  order[eid--] = new MessageEdge(VAR_TO_FAC, vars[vi], node->node_id);
201  // add variable node to node_stack
202  node_stack.push(new GraphNode(vars[vi], VAR_NODE, node->node_id));
203  }
204  }
205 
206  delete node;
207  }
208 
209  SG_UNREF(facs);
210 }
211 
212 float64_t CTreeMaxProduct::inference(SGVector<int32_t> assignment)
213 {
214  REQUIRE(assignment.size() == m_fg->get_cardinalities().size(),
215  "%s::inference(): the output assignment should be prepared as"
216  "the same size as variables!\n", get_name());
217 
218  bottom_up_pass();
219  top_down_pass();
220 
221  for (int32_t vi = 0; vi < assignment.size(); vi++)
222  assignment[vi] = m_states[vi];
223 
224  SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", m_fg->evaluate_energy(assignment));
225  SG_DEBUG("minimized energy = %f\n", -m_map_energy);
226 
227  return -m_map_energy;
228 }
229 
230 void CTreeMaxProduct::bottom_up_pass()
231 {
232  SG_DEBUG("\n***enter bottom_up_pass().\n");
233  CDynamicObjectArray* facs = m_fg->get_factors();
234  SGVector<int32_t> cards = m_fg->get_cardinalities();
235 
236  // init forward msgs to 0
237  m_fw_msgs.resize(m_msg_order.size());
238  for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
239  {
240  // msg size is determined by var node of msg edge
241  m_fw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
242  std::fill(m_fw_msgs[mi].begin(), m_fw_msgs[mi].end(), 0);
243  }
244 
245  // pass msgs along the order up to root
246  // if var -> factor
247  // compute q_v2f
248  // else factor -> var
249  // compute r_f2v
250  // where q_v2f and r_f2v are beliefs of the edge collecting from neighborhoods
251  // by one end, which will be sent to another end, read Eq.(3.19), Eq.(3.20)
252  // on [Nowozin et al. 2011] for more detail.
253  for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
254  {
255  SG_DEBUG("mi = %d, mtype: %d %d -> %d\n", mi,
256  m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
257 
258  if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var -> factor
259  {
260  uint32_t var_id = m_msg_order[mi]->child;
261  const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
262 
263  // q_v2f = sum(r_f2v), i.e. sum all incoming f2v msgs
264  for (std::set<uint32_t>::const_iterator cit = msgset_var.begin(); cit != msgset_var.end(); cit++)
265  {
266  std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
267  m_fw_msgs[mi].begin(),
268  m_fw_msgs[mi].begin(),
269  std::plus<float64_t>());
270  }
271  }
272  else // factor -> var
273  {
274  int32_t fac_id = m_msg_order[mi]->child;
275  int32_t var_id = m_msg_order[mi]->parent;
276 
277  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
278  CTableFactorType* ftype = fac->get_factor_type();
279  SGVector<int32_t> fvars = fac->get_variables();
280  SGVector<float64_t> fenrgs = fac->get_energies();
281  SG_UNREF(fac);
282 
283  // find index of var_id in the factor
284  SGVector<int32_t> fvar_set = fvars.find(var_id);
285  ASSERT(fvar_set.vlen == 1);
286  int32_t var_id_index = fvar_set[0];
287 
288  std::vector<float64_t> r_f2v(fenrgs.size(), 0);
289  std::vector<float64_t> r_f2v_max(cards[var_id],
290  -std::numeric_limits<float64_t>::infinity());
291 
292  // TODO: optimize with index_from_new_state()
293  // marginalization
294  // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state])
295  for (int32_t ei = 0; ei < fenrgs.size(); ei++)
296  {
297  r_f2v[ei] = -fenrgs[ei];
298 
299  for (int32_t vi = 0; vi < fvars.size(); vi++)
300  {
301  if (vi == var_id_index)
302  continue;
303 
304  uint32_t adj_msg = m_msg_map_var[fvars[vi]];
305  int32_t adj_var_state = ftype->state_from_index(ei, vi);
306 
307  r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
308  }
309 
310  // find max value for each state of var_id
311  int32_t var_state = ftype->state_from_index(ei, var_id_index);
312  if (r_f2v[ei] > r_f2v_max[var_state])
313  r_f2v_max[var_state] = r_f2v[ei];
314  }
315 
316  // in max-product, final r_f2v = r_f2v_max
317  for (int32_t si = 0; si < cards[var_id]; si++)
318  m_fw_msgs[mi][si] = r_f2v_max[si];
319 
320  SG_UNREF(ftype);
321  }
322  }
323  SG_UNREF(facs);
324 
325  // -energy = max(sum_{f} r_f2root)
326  m_map_energy = 0;
327  for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
328  {
329  if (!m_is_root[ri])
330  continue;
331 
332  const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
333  std::vector<float64_t> rmarg(cards[ri], 0);
334  for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
335  {
336  std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
337  rmarg.begin(),
338  rmarg.begin(),
339  std::plus<float64_t>());
340  }
341 
342  m_map_energy += *std::max_element(rmarg.begin(), rmarg.end());
343  }
344  SG_DEBUG("***leave bottom_up_pass().\n");
345 }
346 
347 void CTreeMaxProduct::top_down_pass()
348 {
349  SG_DEBUG("\n***enter top_down_pass().\n");
350  int32_t minf = std::numeric_limits<int32_t>::max();
351  CDynamicObjectArray* facs = m_fg->get_factors();
352  SGVector<int32_t> cards = m_fg->get_cardinalities();
353 
354  // init backward msgs to 0
355  m_bw_msgs.resize(m_msg_order.size());
356  for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
357  {
358  // msg size is determined by var node of msg edge
359  m_bw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
360  std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(), 0);
361  }
362 
363  // init states to max infinity
364  m_states.resize(cards.size());
365  std::fill(m_states.begin(), m_states.end(), minf);
366 
367  // infer states of roots first since marginal distributions of
368  // root variables are ready after bottom-up pass
369  for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
370  {
371  if (!m_is_root[ri])
372  continue;
373 
374  const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
375  std::vector<float64_t> rmarg(cards[ri], 0);
376  for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
377  {
378  // rmarg += m_fw_msgs[*cit]
379  std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
380  rmarg.begin(),
381  rmarg.begin(),
382  std::plus<float64_t>());
383  }
384 
385  // argmax
386  m_states[ri] = static_cast<int32_t>(
387  std::max_element(rmarg.begin(), rmarg.end())
388  - rmarg.begin());
389  }
390 
391  // pass msgs down to leaf
392  // if factor <- var edge
393  // compute q_v2f
394  // compute marginal of f
395  // else var <- factor edge
396  // compute r_f2v
397  for (int32_t mi = (int32_t)(m_msg_order.size()-1); mi >= 0; --mi)
398  {
399  SG_DEBUG("mi = %d, mtype: %d %d <- %d\n", mi,
400  m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
401 
402  if (m_msg_order[mi]->mtype == FAC_TO_VAR) // factor <- var
403  {
404  int32_t fac_id = m_msg_order[mi]->child;
405  int32_t var_id = m_msg_order[mi]->parent;
406 
407  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
408  CTableFactorType* ftype = fac->get_factor_type();
409  SGVector<int32_t> fvars = fac->get_variables();
410  SGVector<float64_t> fenrgs = fac->get_energies();
411  SG_UNREF(fac);
412 
413  // find index of var_id in the factor
414  SGVector<int32_t> fvar_set = fvars.find(var_id);
415  ASSERT(fvar_set.vlen == 1);
416  int32_t var_id_index = fvar_set[0];
417 
418  // q_v2f = r_bw_parent2v + sum_{child!=f} r_fw_child2v
419  // make sure the state of var_id has been inferred (factor marginalization)
420  // s.t. this msg computation will condition on the known state
421  ASSERT(m_states[var_id] != minf);
422 
423  // parent msg: r_bw_parent2v
424  if (m_is_root[var_id] == 0)
425  {
426  uint32_t parent_msg = m_msg_map_var[var_id];
427  std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(),
428  m_bw_msgs[parent_msg][m_states[var_id]]);
429  }
430 
431  // siblings: sum_{child!=f} r_fw_child2v
432  const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
433  for (std::set<uint32_t>::const_iterator cit = msgset_var.begin();
434  cit != msgset_var.end(); cit++)
435  {
436  if (m_msg_order[*cit]->child == fac_id)
437  continue;
438 
439  for (uint32_t xi = 0; xi < m_bw_msgs[mi].size(); xi++)
440  m_bw_msgs[mi][xi] += m_fw_msgs[*cit][m_states[var_id]];
441  }
442 
443  // m_states from maximizing marginal distributions of fac_id
444  // mu(f) = -E(v_state) + sum_v q_v2f
445  std::vector<float64_t> marg(fenrgs.size(), 0);
446  for (uint32_t ei = 0; ei < marg.size(); ei++)
447  {
448  int32_t nei = ftype->index_from_new_state(ei, var_id_index, m_states[var_id]);
449  marg[ei] = -fenrgs[nei];
450 
451  for (int32_t vi = 0; vi < fvars.size(); vi++)
452  {
453  if (vi == var_id_index)
454  {
455  int32_t var_id_state = ftype->state_from_index(ei, var_id_index);
456  if (m_states[var_id] != minf)
457  var_id_state = m_states[var_id];
458 
459  marg[ei] += m_bw_msgs[mi][var_id_state];
460  }
461  else
462  {
463  uint32_t adj_id = fvars[vi];
464  uint32_t adj_msg = m_msg_map_var[adj_id];
465  int32_t adj_id_state = ftype->state_from_index(ei, vi);
466 
467  marg[ei] += m_fw_msgs[adj_msg][adj_id_state];
468  }
469  }
470  }
471 
472  int32_t ei_max = static_cast<int32_t>(
473  std::max_element(marg.begin(), marg.end())
474  - marg.begin());
475 
476  // infer states of neiboring vars of f
477  for (int32_t vi = 0; vi < fvars.size(); vi++)
478  {
479  int32_t nvar_id = fvars[vi];
480  // usually parent node has been inferred
481  if (m_states[nvar_id] != minf)
482  continue;
483 
484  int32_t nvar_id_state = ftype->state_from_index(ei_max, vi);
485  m_states[nvar_id] = nvar_id_state;
486  }
487 
488  SG_UNREF(ftype);
489  }
490  else // var <- factor
491  {
492  int32_t var_id = m_msg_order[mi]->child;
493  int32_t fac_id = m_msg_order[mi]->parent;
494 
495  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
496  CTableFactorType* ftype = fac->get_factor_type();
497  SGVector<int32_t> fvars = fac->get_variables();
498  SGVector<float64_t> fenrgs = fac->get_energies();
499  SG_UNREF(fac);
500 
501  // find index of var_id in the factor
502  SGVector<int32_t> fvar_set = fvars.find(var_id);
503  ASSERT(fvar_set.vlen == 1);
504  int32_t var_id_index = fvar_set[0];
505 
506  uint32_t msg_parent = m_msg_map_fac[fac_id];
507  int32_t var_parent = m_msg_order[msg_parent]->parent;
508 
509  std::vector<float64_t> r_f2v(fenrgs.size());
510  std::vector<float64_t> r_f2v_max(cards[var_id],
511  -std::numeric_limits<float64_t>::infinity());
512 
513  // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state])
514  for (int32_t ei = 0; ei < fenrgs.size(); ei++)
515  {
516  r_f2v[ei] = -fenrgs[ei];
517 
518  for (int32_t vi = 0; vi < fvars.size(); vi++)
519  {
520  if (vi == var_id_index)
521  continue;
522 
523  if (fvars[vi] == var_parent)
524  {
525  ASSERT(m_states[var_parent] != minf);
526  r_f2v[ei] += m_bw_msgs[msg_parent][m_states[var_parent]];
527  }
528  else
529  {
530  int32_t adj_id = fvars[vi];
531  uint32_t adj_msg = m_msg_map_var[adj_id];
532  int32_t adj_var_state = ftype->state_from_index(ei, vi);
533 
534  if (m_states[adj_id] != minf)
535  adj_var_state = m_states[adj_id];
536 
537  r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
538  }
539  }
540 
541  // max marginalization
542  int32_t var_id_state = ftype->state_from_index(ei, var_id_index);
543  if (r_f2v[ei] > r_f2v_max[var_id_state])
544  r_f2v_max[var_id_state] = r_f2v[ei];
545  }
546 
547  for (int32_t si = 0; si < cards[var_id]; si++)
548  m_bw_msgs[mi][si] = r_f2v_max[si];
549 
550  SG_UNREF(ftype);
551  }
552  } // end for msg edge
553 
554  SG_UNREF(facs);
555  SG_DEBUG("***leave top_down_pass().\n");
556 }
557 
CTableFactorType * get_factor_type() const
Definition: Factor.cpp:95
const SGVector< int32_t > get_variables() const
Definition: Factor.cpp:107
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
Class CMAPInferImpl abstract class of MAP inference implementation.
Definition: MAPInference.h:98
SGVector< float64_t > get_energies() const
Definition: Factor.cpp:169
int32_t size() const
Definition: SGVector.h:115
index_t vlen
Definition: SGVector.h:494
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
Class CDisjointSet data structure for linking graph nodes It's easy to identify connected graph...
Definition: DisjointSet.h:26
int32_t index_from_new_state(int32_t old_ei, int32_t var_index, int32_t var_state) const
Definition: FactorType.cpp:180
int32_t state_from_index(int32_t ei, int32_t var_index) const
Definition: FactorType.cpp:155
#define SG_UNREF(x)
Definition: SGObject.h:52
#define SG_DEBUG(...)
Definition: SGIO.h:107
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:27
Class CTableFactorType the way that store assignments of variables and energies in a table or a multi...
Definition: FactorType.h:122
CSGObject * get_element(int32_t index) const
Matrix::Scalar max(Matrix m)
Definition: Redux.h:66
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:132
SGVector< index_t > find(T elem)
Definition: SGVector.cpp:809
Class CFactor A factor is defined on a clique in the factor graph. Each factor can have its own data...
Definition: Factor.h:89

SHOGUN 机器学习工具包 - 项目文档