SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
GEMPLP.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2014 Jiaolong Xu
8  * Copyright (C) 2014 Jiaolong Xu
9  */
10 
12 #include <shogun/io/SGIO.h>
13 #include <algorithm>
14 
15 using namespace shogun;
16 using namespace std;
17 
19  : CMAPInferImpl()
20 {
21  m_fg = NULL;
22  m_factors = NULL;
23 }
24 
26  : CMAPInferImpl(fg),
27  m_param(param)
28 {
29  ASSERT(m_fg != NULL);
30 
31  init();
32 }
33 
35 {
36  if(m_factors != NULL)
38 }
39 
40 void CGEMPLP::init()
41 {
42  SGVector<int32_t> fg_var_sizes = m_fg->get_cardinalities();
44 
45  int32_t num_factors = m_factors->get_num_elements();
46  m_region_intersections.resize(num_factors);
47 
48  // get all the intersections
49  for (int32_t i = 0; i < num_factors; i++)
50  {
51  CFactor* factor_i = dynamic_cast<CFactor*>(m_factors->get_element(i));
52  SGVector<int32_t> region_i = factor_i->get_variables();
53  SG_UNREF(factor_i);
54 
55  for (int32_t j = i; j < num_factors; j++)
56  {
57  CFactor* factor_j = dynamic_cast<CFactor*>(m_factors->get_element(j));
58  SGVector<int32_t> region_j = factor_j->get_variables();
59  SG_UNREF(factor_j);
60 
61  const int32_t k = find_intersection_index(region_i, region_j);
62  if (k < 0) continue;
63 
64  m_region_intersections[i].insert(k);
65 
66  if (j != i)
67  m_region_intersections[j].insert(k);
68  }
69  }
70 
71  m_region_inds_intersections.resize(num_factors);
72  m_msgs_from_region.resize(num_factors);
73  m_theta_region.resize(num_factors);
74 
75  for (int32_t c = 0; c < num_factors; c++)
76  {
77  CFactor* factor_c = dynamic_cast<CFactor*>(m_factors->get_element(c));
78  SGVector<int32_t> vars_c = factor_c->get_variables();
79  SG_UNREF(factor_c);
80 
82  m_msgs_from_region[c].resize(m_region_intersections[c].size());
83 
84  int32_t s = 0;
85 
86  for (std::set<int>::iterator t = m_region_intersections[c].begin();
87  t != m_region_intersections[c].end(); t++)
88  {
89  SGVector<int32_t> curr_intersection = m_all_intersections[*t];
90  SGVector<int32_t> inds_s(curr_intersection.size());
91  SGVector<int32_t> dims_array(curr_intersection.size());
92 
93  for (int32_t i = 0; i < inds_s.size(); i++)
94  {
95  inds_s[i] = vars_c.find(curr_intersection[i])[0];
96  REQUIRE(inds_s[i] >= 0,
97  "Intersection contains variable %d which is not in the region %d", curr_intersection[i], c);
98 
99  dims_array[i] = fg_var_sizes[curr_intersection[i]];
100  }
101 
102  // initialize indices of intersections inside the region
103  m_region_inds_intersections[c][s] = inds_s;
104 
105  // initialize messages from region and set it 0
106  SGNDArray<float64_t> message(dims_array);
107  message.set_const(0);
108  m_msgs_from_region[c][s] = message.clone();
109  s++;
110  }
111 
112  // initialize potential on region
114  }
115 
116  // initialize messages in intersections and set it 0
118 
119  for (uint32_t i = 0; i < m_all_intersections.size(); i++)
120  {
121  SGVector<int32_t> vars_intersection = m_all_intersections[i];
122  SGVector<int32_t> dims_array(vars_intersection.size());
123 
124  for (int32_t j = 0; j < dims_array.size(); j++)
125  dims_array[j] = fg_var_sizes[vars_intersection[j]];
126 
127  SGNDArray<float64_t> curr_array(dims_array);
128  curr_array.set_const(0);
129  m_msgs_into_intersections[i] = curr_array.clone();
130  }
131 }
132 
134 {
135  SGVector<float64_t> energies = factor->get_energies();
136  SGVector<int32_t> cards = factor->get_cardinalities();
137 
138  SGNDArray<float64_t> message(cards);
139 
140  if (cards.size() == 1)
141  {
142  for (int32_t i = 0; i < energies.size(); i++)
143  message.array[i] = - energies[i];
144  }
145  else if (cards.size() == 2)
146  {
147  for (int32_t y = 0; y < cards[1]; y++)
148  for (int32_t x = 0; x < cards[0]; x++)
149  message.array[x*cards[1]+y] = - energies[y*cards[0]+x];
150  }
151  else
152  SG_ERROR("Index issue has not been solved for higher order (>=3) factors.");
153 
154  return message.clone();
155 }
156 
158 {
159  vector<int32_t> tmp;
160 
161  for (int32_t i = 0; i < region_A.size(); i++)
162  {
163  for (int32_t j = 0; j < region_B.size(); j++)
164  {
165  if (region_A[i] == region_B[j])
166  tmp.push_back(region_A[i]);
167  }
168  }
169 
170  // return -1 if intersetion is empty
171  if (tmp.size() == 0) return -1;
172 
173 
174  SGVector<int32_t> sAB(tmp.size());
175  for (uint32_t i = 0; i < tmp.size(); i++)
176  sAB[i] = tmp[i];
177 
178  // find (or add) intersection set
179  int32_t k;
180  for (k = 0; k < (int32_t)m_all_intersections.size(); k++)
181  if (m_all_intersections[k].equals(sAB))
182  break;
183 
184  if (k == (int32_t)m_all_intersections.size())
185  m_all_intersections.push_back(sAB);
186 
187  return k;
188 }
189 
191 {
192  REQUIRE(assignment.size() == m_fg->get_cardinalities().size(),
193  "%s::inference(): the output assignment should be prepared as"
194  "the same size as variables!\n", get_name());
195 
196  // iterate over message loop
197  SG_SDEBUG("Running MPLP for %d iterations\n", m_param.m_max_iter);
198 
199  float64_t last_obj = CMath::INFTY;
200 
201  // block coordinate desent, outer loop
202  for (int32_t it = 0; it < m_param.m_max_iter; ++it)
203  {
204  // update message, iterate over all regions
205  for (int32_t c = 0; c < m_factors->get_num_elements(); c++)
206  {
207  CFactor* factor_c = dynamic_cast<CFactor*>(m_factors->get_element(c));
208  SGVector<int32_t> vars = factor_c->get_variables();
209  SG_UNREF(factor_c);
210 
211  if (vars.size() == 1 && it > 0)
212  continue;
213 
214  update_messages(c);
215  }
216 
217  // calculate the objective value
218  float64_t obj = 0;
219  int32_t max_at;
220 
221  for (uint32_t s = 0; s < m_msgs_into_intersections.size(); s++)
222  {
223  obj += m_msgs_into_intersections[s].max_element(max_at);
224 
225  if (m_all_intersections[s].size() == 1)
226  assignment[m_all_intersections[s][0]] = max_at;
227  }
228 
229  // get the value of the decoded solution
230  float64_t int_val = 0;
231 
232  // iterates over factors
233  for (int32_t c = 0; c < m_factors->get_num_elements(); c++)
234  {
235  CFactor* factor = dynamic_cast<CFactor*>(m_factors->get_element(c));
236  SGVector<int32_t> vars = factor->get_variables();
237  SGVector<int32_t> var_assignment(vars.size());
238 
239  for (int32_t i = 0; i < vars.size(); i++)
240  var_assignment[i] = assignment[vars[i]];
241 
242  // add value from current factor
243  int_val += m_theta_region[c].get_value(var_assignment);
244 
245  SG_UNREF(factor);
246  }
247 
248  float64_t obj_del = last_obj - obj;
249  float64_t int_gap = obj - int_val;
250 
251  SG_SDEBUG("Iter= %d Objective=%f ObjBest=%f ObjDel=%f Gap=%f \n", (it + 1), obj, int_val, obj_del, int_gap);
252 
253  if (obj_del < m_param.m_obj_del_thr && it > 16)
254  break;
255 
256  if (int_gap < m_param.m_int_gap_thr)
257  break;
258 
259  last_obj = obj;
260  }
261 
262  float64_t energy = m_fg->evaluate_energy(assignment);
263  SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", energy);
264 
265  return energy;
266 }
267 
268 void CGEMPLP::update_messages(int32_t id_region)
269 {
270  REQUIRE(m_factors != NULL, "Factors are not set!\n");
271 
272  REQUIRE(m_factors->get_num_elements() > id_region,
273  "Region id (%d) exceeds the factor elements' length (%d)!\n",
274  id_region, m_factors->get_num_elements());
275 
276  CFactor* factor = dynamic_cast<CFactor*>(m_factors->get_element(id_region));
277  SGVector<int32_t> vars = factor->get_variables();
278  SGVector<int32_t> cards = factor->get_cardinalities();
279  SGNDArray<float64_t> lam_sum(cards);
280 
281  if (m_theta_region[id_region].len_array == 0)
282  lam_sum.set_const(0);
283  else
284  lam_sum = m_theta_region[id_region].clone();
285 
286  int32_t num_intersections = m_region_intersections[id_region].size();
287  vector<SGNDArray<float64_t> > lam_minus; // substract message: \lambda_s^{-c}(x_s)
288  // \sum_{\hat{s}} \lambda_{\hat{s}}^{-c}(x_{\hat{s}}) + \theta_c(x_c)
289  int32_t s = 0;
290 
291  for (std::set<int32_t>::iterator t = m_region_intersections[id_region].begin();
292  t != m_region_intersections[id_region].end(); t++)
293  {
294  int32_t id_intersection = *t;
295  SGNDArray<float64_t> tmp = m_msgs_into_intersections[id_intersection].clone();
296  tmp -= m_msgs_from_region[id_region][s];
297 
298  lam_minus.push_back(tmp);
299 
300  if (vars.size() == (int32_t)m_region_inds_intersections[id_region][s].size())
301  lam_sum += tmp;
302  else
303  {
304  SGNDArray<float64_t> tmp_expand(lam_sum.get_dimensions());
305  tmp.expand(tmp_expand, m_region_inds_intersections[id_region][s]);
306  lam_sum += tmp_expand;
307  }
308 
309  // take out the old incoming message: \lambda_{c \to s}(x_s)
310  m_msgs_into_intersections[id_intersection] -= m_msgs_from_region[id_region][s];
311  s++;
312  }
313 
314  s = 0;
315 
316  for (std::set<int32_t>::iterator t = m_region_intersections[id_region].begin();
317  t != m_region_intersections[id_region].end(); t++)
318  {
319  // maximazation: \max_{x_c} \sum_{\hat{s}} \lambda_{\hat{s}}^{-c}(x_{\hat{s}}) + \theta_c(x_c)
320  SGNDArray<float64_t> lam_max(lam_minus[s].get_dimensions());
321  max_in_subdimension(lam_sum, m_region_inds_intersections[id_region][s], lam_max);
322  int32_t id_intersection = *t;
323  // weighted sum
324  lam_max *= 1.0/num_intersections;
325  m_msgs_from_region[id_region][s] = lam_max.clone();
326  m_msgs_from_region[id_region][s] -= lam_minus[s];
327  // put in new message
328  m_msgs_into_intersections[id_intersection] += m_msgs_from_region[id_region][s];
329  s++;
330  }
331 
332  SG_UNREF(factor);
333 }
334 
336 {
337  // If the subset equals the target array then maximizing would
338  // give us the target array (assuming there is no reordering)
339  if (subset_inds.size() == tar_arr.num_dims)
340  {
341  max_res = tar_arr.clone();
342  return;
343  }
344  else
345  max_res.set_const(-CMath::INFTY);
346 
347  // Go over all values of the target array. For each check if its
348  // value on the subset is larger than the current max
349  SGVector<int32_t> inds_for_tar(tar_arr.num_dims);
350  inds_for_tar.zero();
351 
352  for (int32_t vi = 0; vi < tar_arr.len_array; vi++)
353  {
354  int32_t y = 0;
355 
356  if (subset_inds.size() == 1)
357  y = inds_for_tar[subset_inds[0]];
358  else if (subset_inds.size() == 2)
359  {
360  int32_t ind1 = subset_inds[0];
361  int32_t ind2 = subset_inds[1];
362  y = inds_for_tar[ind1] * max_res.dims[1] + inds_for_tar[ind2];
363  }
364  max_res[y] = max(max_res[y], tar_arr.array[vi]);
365  tar_arr.next_index(inds_for_tar);
366  }
367 }
std::vector< SGVector< int32_t > > m_all_intersections
Definition: GEMPLP.h:138
float64_t evaluate_energy(const SGVector< int32_t > state) const
float64_t m_int_gap_thr
Definition: GEMPLP.h:63
virtual ~CGEMPLP()
Definition: GEMPLP.cpp:34
std::vector< std::vector< SGNDArray< float64_t > > > m_msgs_from_region
Definition: GEMPLP.h:146
static const float64_t INFTY
infinity
Definition: Math.h:2048
std::vector< std::vector< SGVector< int32_t > > > m_region_inds_intersections
Definition: GEMPLP.h:142
SGNDArray< T > clone() const
Definition: SGNDArray.cpp:105
const SGVector< int32_t > get_variables() const
Definition: Factor.cpp:107
Definition: basetag.h:132
CDynamicObjectArray * get_factors() const
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
Class CMAPInferImpl abstract class of MAP inference implementation.
Definition: MAPInference.h:98
SGVector< float64_t > get_energies() const
Definition: Factor.cpp:169
int32_t size() const
Definition: SGVector.h:113
CFactorGraph * m_fg
Definition: MAPInference.h:128
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
Parameter m_param
Definition: GEMPLP.h:134
index_t num_dims
Definition: SGNDArray.h:180
virtual float64_t inference(SGVector< int32_t > assignment)
Definition: GEMPLP.cpp:190
void max_in_subdimension(SGNDArray< float64_t > tar_arr, SGVector< int32_t > &subset_inds, SGNDArray< float64_t > &max_res) const
Definition: GEMPLP.cpp:335
virtual const char * get_name() const
Definition: GEMPLP.h:81
virtual bool equals(CSGObject *other, float64_t accuracy=0.0, bool tolerant=false)
Definition: SGObject.cpp:651
SGNDArray< float64_t > convert_energy_to_potential(CFactor *factor)
Definition: GEMPLP.cpp:133
void set_const(T const_elem)
Definition: SGNDArray.cpp:146
CDynamicObjectArray * m_factors
Definition: GEMPLP.h:136
#define SG_UNREF(x)
Definition: SGObject.h:55
index_t len_array
Definition: SGNDArray.h:183
#define SG_DEBUG(...)
Definition: SGIO.h:107
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
int32_t find_intersection_index(SGVector< int32_t > region_A, SGVector< int32_t > region_B)
Definition: GEMPLP.cpp:157
#define SG_SDEBUG(...)
Definition: SGIO.h:168
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:27
index_t * dims
Definition: SGNDArray.h:177
std::vector< std::set< int32_t > > m_region_intersections
Definition: GEMPLP.h:140
std::vector< SGNDArray< float64_t > > m_theta_region
Definition: GEMPLP.h:148
CSGObject * get_element(int32_t index) const
Matrix::Scalar max(Matrix m)
Definition: Redux.h:68
SGVector< int32_t > get_cardinalities() const
void next_index(SGVector< index_t > &curr_index) const
Definition: SGNDArray.cpp:286
std::vector< SGNDArray< float64_t > > m_msgs_into_intersections
Definition: GEMPLP.h:144
SGVector< index_t > find(T elem)
Definition: SGVector.cpp:807
Class CFactor A factor is defined on a clique in the factor graph. Each factor can have its own data...
Definition: Factor.h:89
const SGVector< int32_t > get_cardinalities() const
Definition: Factor.cpp:122
void expand(SGNDArray &big_array, SGVector< index_t > &axes)
Definition: SGNDArray.cpp:303

SHOGUN Machine Learning Toolbox - Documentation