21 CBeliefPropagation::CBeliefPropagation()
24 SG_UNSTABLE(
"CBeliefPropagation::CBeliefPropagation()",
"\n");
32 CBeliefPropagation::~CBeliefPropagation()
38 SG_ERROR(
"%s::inference(): please use TreeMaxProduct or LoopyMaxProduct!\n", get_name());
44 CTreeMaxProduct::CTreeMaxProduct()
45 : CBeliefPropagation()
47 SG_UNSTABLE(
"CTreeMaxProduct::CTreeMaxProduct()",
"\n");
53 : CBeliefPropagation(fg)
64 m_fg->connect_components();
66 get_message_order(m_msg_order, m_is_root);
71 for (uint32_t mi = 0; mi < m_msg_order.size(); mi++)
73 if (m_msg_order[mi]->mtype == VAR_TO_FAC)
76 m_msg_map_var[m_msg_order[mi]->child] = mi;
81 m_msg_map_fac[m_msg_order[mi]->child] = mi;
83 m_msgset_map_var[m_msg_order[mi]->parent].insert(mi);
89 CTreeMaxProduct::~CTreeMaxProduct()
91 if (!m_msg_order.empty())
93 for (std::vector<MessageEdge*>::iterator it = m_msg_order.begin(); it != m_msg_order.end(); ++it)
98 void CTreeMaxProduct::init()
100 m_msg_order = std::vector<MessageEdge*>(m_fg->get_num_edges(), (MessageEdge*) NULL);
101 m_is_root = std::vector<bool>(m_fg->get_cardinalities().size(),
false);
102 m_fw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
103 std::vector<float64_t>());
104 m_bw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
105 std::vector<float64_t>());
106 m_states = std::vector<int32_t>(m_fg->get_cardinalities().size(), 0);
108 m_msg_map_var = msg_map_type();
109 m_msg_map_fac = msg_map_type();
110 m_msgset_map_var = msgset_map_type();
113 void CTreeMaxProduct::get_message_order(std::vector<MessageEdge*>& order,
114 std::vector<bool>& is_root)
const
116 ASSERT(m_fg->is_acyclic_graph());
120 if (!dset->get_connected())
123 SG_ERROR(
"%s::get_root_indicators(): run connect_components() first!\n", get_name());
126 int32_t num_vars = m_fg->get_cardinalities().size();
127 if (is_root.size() != (uint32_t)num_vars)
128 is_root.resize(num_vars);
130 std::fill(is_root.begin(), is_root.end(),
false);
132 for (int32_t vi = 0; vi < num_vars; vi++)
133 is_root[dset->find_set(vi)] =
true;
136 ASSERT(std::accumulate(is_root.begin(), is_root.end(), 0) >= 1);
140 var_factor_map_type vf_map;
147 for (int32_t vi = 0; vi < vars.
size(); vi++)
148 vf_map.insert(var_factor_map_type::value_type(vars[vi], fi));
153 std::stack<GraphNode*> node_stack;
155 for (uint32_t ni = 0; ni < is_root.size(); ni++)
160 node_stack.push(
new GraphNode(ni, VAR_NODE, -1));
164 if (order.size() != (uint32_t)(m_fg->get_num_edges()))
165 order.resize(m_fg->get_num_edges());
168 int32_t eid = m_fg->get_num_edges() - 1;
169 while (!node_stack.empty())
171 GraphNode*
node = node_stack.top();
174 if (node->node_type == VAR_NODE)
176 typedef var_factor_map_type::const_iterator const_iter;
177 std::pair<const_iter, const_iter> adj_factors = vf_map.equal_range(node->node_id);
178 for (const_iter mi = adj_factors.first; mi != adj_factors.second; ++mi)
180 int32_t adj_factor_id = mi->second;
181 if (adj_factor_id == node->parent)
184 order[eid--] =
new MessageEdge(FAC_TO_VAR, adj_factor_id, node->node_id);
186 node_stack.push(
new GraphNode(adj_factor_id, FAC_NODE, node->node_id));
195 for (int32_t vi = 0; vi < vars.
size(); vi++)
197 if (vars[vi] == node->parent)
200 order[eid--] =
new MessageEdge(VAR_TO_FAC, vars[vi], node->node_id);
202 node_stack.push(
new GraphNode(vars[vi], VAR_NODE, node->node_id));
214 REQUIRE(assignment.
size() == m_fg->get_cardinalities().size(),
215 "%s::inference(): the output assignment should be prepared as"
216 "the same size as variables!\n", get_name());
221 for (int32_t vi = 0; vi < assignment.
size(); vi++)
222 assignment[vi] = m_states[vi];
224 SG_DEBUG(
"fg.evaluate_energy(assignment) = %f\n", m_fg->evaluate_energy(assignment));
225 SG_DEBUG(
"minimized energy = %f\n", -m_map_energy);
227 return -m_map_energy;
230 void CTreeMaxProduct::bottom_up_pass()
232 SG_DEBUG(
"\n***enter bottom_up_pass().\n");
237 m_fw_msgs.resize(m_msg_order.size());
238 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
241 m_fw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
242 std::fill(m_fw_msgs[mi].begin(), m_fw_msgs[mi].end(), 0);
253 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
255 SG_DEBUG(
"mi = %d, mtype: %d %d -> %d\n", mi,
256 m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
258 if (m_msg_order[mi]->mtype == VAR_TO_FAC)
260 uint32_t var_id = m_msg_order[mi]->child;
261 const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
264 for (std::set<uint32_t>::const_iterator cit = msgset_var.begin(); cit != msgset_var.end(); cit++)
266 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
267 m_fw_msgs[mi].begin(),
268 m_fw_msgs[mi].begin(),
269 std::plus<float64_t>());
274 int32_t fac_id = m_msg_order[mi]->child;
275 int32_t var_id = m_msg_order[mi]->parent;
286 int32_t var_id_index = fvar_set[0];
288 std::vector<float64_t> r_f2v(fenrgs.
size(), 0);
289 std::vector<float64_t> r_f2v_max(cards[var_id],
290 -std::numeric_limits<float64_t>::infinity());
295 for (int32_t ei = 0; ei < fenrgs.
size(); ei++)
297 r_f2v[ei] = -fenrgs[ei];
299 for (int32_t vi = 0; vi < fvars.
size(); vi++)
301 if (vi == var_id_index)
304 uint32_t adj_msg = m_msg_map_var[fvars[vi]];
307 r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
312 if (r_f2v[ei] > r_f2v_max[var_state])
313 r_f2v_max[var_state] = r_f2v[ei];
317 for (int32_t si = 0; si < cards[var_id]; si++)
318 m_fw_msgs[mi][si] = r_f2v_max[si];
327 for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
332 const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
333 std::vector<float64_t> rmarg(cards[ri], 0);
334 for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
336 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
339 std::plus<float64_t>());
342 m_map_energy += *std::max_element(rmarg.begin(), rmarg.end());
344 SG_DEBUG(
"***leave bottom_up_pass().\n");
347 void CTreeMaxProduct::top_down_pass()
349 SG_DEBUG(
"\n***enter top_down_pass().\n");
355 m_bw_msgs.resize(m_msg_order.size());
356 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
359 m_bw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
360 std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(), 0);
364 m_states.resize(cards.
size());
365 std::fill(m_states.begin(), m_states.end(), minf);
369 for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
374 const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
375 std::vector<float64_t> rmarg(cards[ri], 0);
376 for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
379 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
382 std::plus<float64_t>());
386 m_states[ri] =
static_cast<int32_t
>(
387 std::max_element(rmarg.begin(), rmarg.end())
397 for (int32_t mi = (int32_t)(m_msg_order.size()-1); mi >= 0; --mi)
399 SG_DEBUG(
"mi = %d, mtype: %d %d <- %d\n", mi,
400 m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
402 if (m_msg_order[mi]->mtype == FAC_TO_VAR)
404 int32_t fac_id = m_msg_order[mi]->child;
405 int32_t var_id = m_msg_order[mi]->parent;
416 int32_t var_id_index = fvar_set[0];
421 ASSERT(m_states[var_id] != minf);
424 if (m_is_root[var_id] == 0)
426 uint32_t parent_msg = m_msg_map_var[var_id];
427 std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(),
428 m_bw_msgs[parent_msg][m_states[var_id]]);
432 const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
433 for (std::set<uint32_t>::const_iterator cit = msgset_var.begin();
434 cit != msgset_var.end(); cit++)
436 if (m_msg_order[*cit]->child == fac_id)
439 for (uint32_t xi = 0; xi < m_bw_msgs[mi].size(); xi++)
440 m_bw_msgs[mi][xi] += m_fw_msgs[*cit][m_states[var_id]];
445 std::vector<float64_t> marg(fenrgs.
size(), 0);
446 for (uint32_t ei = 0; ei < marg.size(); ei++)
449 marg[ei] = -fenrgs[nei];
451 for (int32_t vi = 0; vi < fvars.
size(); vi++)
453 if (vi == var_id_index)
456 if (m_states[var_id] != minf)
457 var_id_state = m_states[var_id];
459 marg[ei] += m_bw_msgs[mi][var_id_state];
463 uint32_t adj_id = fvars[vi];
464 uint32_t adj_msg = m_msg_map_var[adj_id];
467 marg[ei] += m_fw_msgs[adj_msg][adj_id_state];
472 int32_t ei_max =
static_cast<int32_t
>(
473 std::max_element(marg.begin(), marg.end())
477 for (int32_t vi = 0; vi < fvars.
size(); vi++)
479 int32_t nvar_id = fvars[vi];
481 if (m_states[nvar_id] != minf)
485 m_states[nvar_id] = nvar_id_state;
492 int32_t var_id = m_msg_order[mi]->child;
493 int32_t fac_id = m_msg_order[mi]->parent;
504 int32_t var_id_index = fvar_set[0];
506 uint32_t msg_parent = m_msg_map_fac[fac_id];
507 int32_t var_parent = m_msg_order[msg_parent]->parent;
509 std::vector<float64_t> r_f2v(fenrgs.
size());
510 std::vector<float64_t> r_f2v_max(cards[var_id],
511 -std::numeric_limits<float64_t>::infinity());
514 for (int32_t ei = 0; ei < fenrgs.
size(); ei++)
516 r_f2v[ei] = -fenrgs[ei];
518 for (int32_t vi = 0; vi < fvars.
size(); vi++)
520 if (vi == var_id_index)
523 if (fvars[vi] == var_parent)
525 ASSERT(m_states[var_parent] != minf);
526 r_f2v[ei] += m_bw_msgs[msg_parent][m_states[var_parent]];
530 int32_t adj_id = fvars[vi];
531 uint32_t adj_msg = m_msg_map_var[adj_id];
534 if (m_states[adj_id] != minf)
535 adj_var_state = m_states[adj_id];
537 r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
543 if (r_f2v[ei] > r_f2v_max[var_id_state])
544 r_f2v_max[var_id_state] = r_f2v[ei];
547 for (int32_t si = 0; si < cards[var_id]; si++)
548 m_bw_msgs[mi][si] = r_f2v_max[si];
555 SG_DEBUG(
"***leave top_down_pass().\n");
CTableFactorType * get_factor_type() const
const SGVector< int32_t > get_variables() const
int32_t get_num_elements() const
Class CMAPInferImpl: abstract class for MAP inference implementations.
SGVector< float64_t > get_energies() const
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
Class CDisjointSet: data structure for linking graph nodes. It's easy to identify connected graph...
int32_t index_from_new_state(int32_t old_ei, int32_t var_index, int32_t var_state) const
int32_t state_from_index(int32_t ei, int32_t var_index) const
All classes and functions are contained in the shogun namespace.
Class CFactorGraph: a factor graph is a structured input in general.
Class CTableFactorType: the way to store assignments of variables and energies in a table or a multi...
CSGObject * get_element(int32_t index) const
Matrix::Scalar max(Matrix m)
#define SG_UNSTABLE(func,...)
SGVector< index_t > find(T elem)
Class CFactor: a factor is defined on a clique in the factor graph. Each factor can have its own data...