SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
GEMPLP.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2014 Jiaolong Xu
8  * Copyright (C) 2014 Jiaolong Xu
9  */
10 
12 #include <shogun/io/SGIO.h>
13 #include <algorithm>
14 
15 using namespace shogun;
16 
// Default constructor (its signature, source line 17, is not present in this
// listing): leaves both m_fg and m_factors NULL and does not call init().
18  : CMAPInferImpl()
19 {
20  m_fg = NULL;
21  m_factors = NULL;
22 }
23 
// Constructor taking a factor graph `fg` and a Parameter `param` (its signature,
// source line 24, is not present in this listing; the names come from the
// initializer list). Asserts that a factor graph was supplied and then calls
// init() to build the intersection and message structures.
25  : CMAPInferImpl(fg),
26  m_param(param)
27 {
28  ASSERT(m_fg != NULL);
29 
30  init();
31 }
32 
// Destructor (virtual ~CGEMPLP(), per the cross-reference entry GEMPLP.cpp:33).
// NOTE(review): the statement guarded by this `if` (source line 36) is missing
// from this listing; presumably it releases m_factors (e.g. SG_UNREF) — confirm
// against the original file.
34 {
35  if(m_factors != NULL)
37 }
38 
// Builds the data structures used by GEMPLP message passing (called by the
// parameterised constructor):
//  - m_all_intersections: every distinct non-empty set of variables shared by a
//    pair of factors ("regions"). Each factor is also paired with itself, so
//    (for a factor with at least one variable) its own variable set is always
//    one of the intersections.
//  - m_region_intersections[c]: indices into m_all_intersections of the
//    intersections that involve region c.
//  - m_region_inds_intersections[c][s]: positions, inside region c's variable
//    list, of the variables of its s-th intersection.
//  - m_msgs_from_region / m_msgs_into_intersections: zero-filled message tables
//    whose dimensions are the cardinalities of the intersection's variables.
//  - m_theta_region: one entry per region (filled on a line not shown here).
39 void CGEMPLP::init()
40 {
41  SGVector<int32_t> fg_var_sizes = m_fg->get_cardinalities();
// NOTE(review): source line 42 is missing from this listing. m_factors is
// dereferenced below, so presumably that line assigns it (e.g. from
// m_fg->get_factors()) — confirm against the original file.
43 
44  int32_t num_factors = m_factors->get_num_elements();
45  m_region_intersections.resize(num_factors);
46 
47  // get all the intersections
48  for (int32_t i = 0; i < num_factors; i++)
49  {
50  CFactor* factor_i = dynamic_cast<CFactor*>(m_factors->get_element(i));
51  SGVector<int32_t> region_i = factor_i->get_variables();
52  SG_UNREF(factor_i);
53 
// j starts at i so that each unordered pair is visited once, including the
// pair (i, i) which yields the region's own variable set.
54  for (int32_t j = i; j < num_factors; j++)
55  {
56  CFactor* factor_j = dynamic_cast<CFactor*>(m_factors->get_element(j));
57  SGVector<int32_t> region_j = factor_j->get_variables();
58  SG_UNREF(factor_j);
59 
60  const int32_t k = find_intersection_index(region_i, region_j);
61  if (k < 0) continue;
62 
63  m_region_intersections[i].insert(k);
64 
65  if (j != i)
66  m_region_intersections[j].insert(k);
67  }
68  }
69 
70  m_region_inds_intersections.resize(num_factors);
71  m_msgs_from_region.resize(num_factors);
72  m_theta_region.resize(num_factors);
73 
74  for (int32_t c = 0; c < num_factors; c++)
75  {
76  CFactor* factor_c = dynamic_cast<CFactor*>(m_factors->get_element(c));
77  SGVector<int32_t> vars_c = factor_c->get_variables();
78  SG_UNREF(factor_c);
79 
// NOTE(review): source line 80 is missing from this listing. Below,
// m_region_inds_intersections[c] is indexed with [s], so presumably that line
// resizes it to m_region_intersections[c].size() — confirm.
81  m_msgs_from_region[c].resize(m_region_intersections[c].size());
82 
83  int32_t s = 0;
84 
// s is the position of intersection *t within region c's (ordered) set; it is
// the same index used by update_messages() for the per-region vectors.
85  for (set<int>::iterator t = m_region_intersections[c].begin();
86  t != m_region_intersections[c].end(); t++)
87  {
88  SGVector<int32_t> curr_intersection = m_all_intersections[*t];
89  SGVector<int32_t> inds_s(curr_intersection.size());
90  SGVector<int32_t> dims_array(curr_intersection.size());
91 
92  for (int32_t i = 0; i < inds_s.size(); i++)
93  {
// NOTE(review): if vars_c.find() returns an empty vector when the variable is
// absent (not visible here — confirm), the [0] access happens before the
// REQUIRE below can report it. By construction every intersection of c is a
// subset of c's variables, so this should not trigger in practice.
94  inds_s[i] = vars_c.find(curr_intersection[i])[0];
95  REQUIRE(inds_s[i] >= 0,
96  "Intersection contains variable %d which is not in the region %d", curr_intersection[i], c);
97 
98  dims_array[i] = fg_var_sizes[curr_intersection[i]];
99  }
100 
101  // initialize indices of intersections inside the region
102  m_region_inds_intersections[c][s] = inds_s;
103 
104  // initialize messages from region and set it 0
105  SGNDArray<float64_t> message(dims_array);
106  message.set_const(0);
107  m_msgs_from_region[c][s] = message.clone();
108  s++;
109  }
110 
111  // initialize potential on region
// NOTE(review): the statement on source line 112 is missing from this listing;
// presumably it sets m_theta_region[c] (this file defines
// convert_energy_to_potential(CFactor*)) — confirm.
113  }
114 
115  // initialize messages in intersections and set it 0
// NOTE(review): source line 116 is missing from this listing. Below,
// m_msgs_into_intersections is indexed by i, so presumably that line resizes it
// to m_all_intersections.size() — confirm.
117 
118  for (uint32_t i = 0; i < m_all_intersections.size(); i++)
119  {
120  SGVector<int32_t> vars_intersection = m_all_intersections[i];
121  SGVector<int32_t> dims_array(vars_intersection.size());
122 
123  for (int32_t j = 0; j < dims_array.size(); j++)
124  dims_array[j] = fg_var_sizes[vars_intersection[j]];
125 
126  SGNDArray<float64_t> curr_array(dims_array);
127  curr_array.set_const(0);
128  m_msgs_into_intersections[i] = curr_array.clone();
129  }
130 }
131 
// convert_energy_to_potential(CFactor* factor) — its signature (source line 132)
// is not shown in this listing; see the cross-reference entry GEMPLP.cpp:132.
// Returns the negated energies of `factor` (potential = -energy) as an
// SGNDArray whose dimensions are the factor's cardinalities.
// For a pairwise factor the table is transposed while copying: energies are
// read with the first variable varying fastest (y*cards[0]+x) and written with
// the second variable varying fastest (x*cards[1]+y).
// Factors with any other number of variables (0, or 3 and more) raise SG_ERROR.
133 {
134  SGVector<float64_t> energies = factor->get_energies();
135  SGVector<int32_t> cards = factor->get_cardinalities();
136 
137  SGNDArray<float64_t> message(cards);
138 
139  if (cards.size() == 1)
140  {
141  for (int32_t i = 0; i < energies.size(); i++)
142  message.array[i] = - energies[i];
143  }
144  else if (cards.size() == 2)
145  {
146  for (int32_t y = 0; y < cards[1]; y++)
147  for (int32_t x = 0; x < cards[0]; x++)
148  message.array[x*cards[1]+y] = - energies[y*cards[0]+x];
149  }
150  else
151  SG_ERROR("Index issue has not been solved for higher order (>=3) factors.");
152 
// NOTE(review): cloning a local array before returning it looks redundant;
// confirm SGNDArray's copy semantics before simplifying.
153  return message.clone();
154 }
155 
// find_intersection_index(region_A, region_B) — its signature (source line 156)
// is not shown in this listing; see the cross-reference entry GEMPLP.cpp:156.
// Collects the variables shared by region_A and region_B, in region_A's order.
// Returns the index of that set in m_all_intersections, appending the set first
// if it is not already stored. Returns -1 when the regions share no variable.
// Side effect: may grow m_all_intersections.
157 {
158  vector<int32_t> tmp;
159 
160  for (int32_t i = 0; i < region_A.size(); i++)
161  {
162  for (int32_t j = 0; j < region_B.size(); j++)
163  {
164  if (region_A[i] == region_B[j])
165  tmp.push_back(region_A[i]);
166  }
167  }
168 
169  // return -1 if intersection is empty
170  if (tmp.size() == 0) return -1;
171 
172 
173  SGVector<int32_t> sAB(tmp.size());
174  for (uint32_t i = 0; i < tmp.size(); i++)
175  sAB[i] = tmp[i];
176 
177  // find (or add) intersection set
// NOTE(review): this is a linear search. If SGVector::equals compares element
// by element in order (not verified here), the same variable set produced in a
// different order would be stored twice.
178  int32_t k;
179  for (k = 0; k < (int32_t)m_all_intersections.size(); k++)
180  if (m_all_intersections[k].equals(sAB))
181  break;
182 
183  if (k == (int32_t)m_all_intersections.size())
184  m_all_intersections.push_back(sAB);
185 
186  return k;
187 }
188 
// inference(SGVector<int32_t> assignment) — its signature (source line 189) is
// not shown in this listing; see the cross-reference entry GEMPLP.cpp:189.
// Runs GEMPLP block-coordinate message passing for at most m_param.m_max_iter
// iterations and writes the decoded labels into `assignment`. Returns the
// factor-graph energy of that assignment (m_fg->evaluate_energy).
// `assignment` must already hold one entry per variable of the factor graph.
190 {
// NOTE(review): the two message literals below concatenate without a space
// ("...prepared asthe same size..."); fix the text in a code change.
191  REQUIRE(assignment.size() == m_fg->get_cardinalities().size(),
192  "%s::inference(): the output assignment should be prepared as"
193  "the same size as variables!\n", get_name());
194 
195  // iterate over message loop
196  SG_SDEBUG("Running MPLP for %d iterations\n", m_param.m_max_iter);
197 
198  float64_t last_obj = CMath::INFTY;
199 
200  // block coordinate descent, outer loop
201  for (int32_t it = 0; it < m_param.m_max_iter; ++it)
202  {
203  // update message, iterate over all regions
204  for (int32_t c = 0; c < m_factors->get_num_elements(); c++)
205  {
206  CFactor* factor_c = dynamic_cast<CFactor*>(m_factors->get_element(c));
207  SGVector<int32_t> vars = factor_c->get_variables();
208  SG_UNREF(factor_c);
209 
// single-variable regions are only updated during the first iteration
210  if (vars.size() == 1 && it > 0)
211  continue;
212 
213  update_messages(c);
214  }
215 
216  // calculate the objective value
// obj is the sum, over all intersections, of the maximum entry of each
// incoming-message table. Singleton intersections also supply the decoded
// label of their variable. Variables not covered by a singleton intersection
// keep whatever value `assignment` already held.
// NOTE(review): max_at is uninitialised here; presumably max_element() always
// writes it (out-parameter holding the position of the maximum) — confirm.
217  float64_t obj = 0;
218  int32_t max_at;
219 
220  for (uint32_t s = 0; s < m_msgs_into_intersections.size(); s++)
221  {
222  obj += m_msgs_into_intersections[s].max_element(max_at);
223 
224  if (m_all_intersections[s].size() == 1)
225  assignment[m_all_intersections[s][0]] = max_at;
226  }
227 
228  // get the value of the decoded solution
229  float64_t int_val = 0;
230 
231  // iterates over factors
232  for (int32_t c = 0; c < m_factors->get_num_elements(); c++)
233  {
234  CFactor* factor = dynamic_cast<CFactor*>(m_factors->get_element(c));
235  SGVector<int32_t> vars = factor->get_variables();
236  SGVector<int32_t> var_assignment(vars.size());
237 
238  for (int32_t i = 0; i < vars.size(); i++)
239  var_assignment[i] = assignment[vars[i]];
240 
241  // add value from current factor
242  int_val += m_theta_region[c].get_value(var_assignment);
243 
244  SG_UNREF(factor);
245  }
246 
247  float64_t obj_del = last_obj - obj;
248  float64_t int_gap = obj - int_val;
249 
250  SG_SDEBUG("Iter= %d Objective=%f ObjBest=%f ObjDel=%f Gap=%f \n", (it + 1), obj, int_val, obj_del, int_gap);
251 
// Stop early when the decrease of the objective falls below m_obj_del_thr
// (only checked once it > 16, i.e. from the 18th iteration on), or when the
// gap between obj and the decoded value falls below m_int_gap_thr.
252  if (obj_del < m_param.m_obj_del_thr && it > 16)
253  break;
254 
255  if (int_gap < m_param.m_int_gap_thr)
256  break;
257 
258  last_obj = obj;
259  }
260 
261  float64_t energy = m_fg->evaluate_energy(assignment);
// NOTE(review): SG_DEBUG is used here while the rest of this function uses
// SG_SDEBUG — confirm which one is intended.
262  SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", energy);
263 
264  return energy;
265 }
266 
267 void CGEMPLP::update_messages(int32_t id_region)
268 {
269  REQUIRE(m_factors != NULL, "Factors are not set!\n");
270 
271  REQUIRE(m_factors->get_num_elements() > id_region,
272  "Region id (%d) exceeds the factor elements' length (%d)!\n",
273  id_region, m_factors->get_num_elements());
274 
275  CFactor* factor = dynamic_cast<CFactor*>(m_factors->get_element(id_region));
276  SGVector<int32_t> vars = factor->get_variables();
277  SGVector<int32_t> cards = factor->get_cardinalities();
278  SGNDArray<float64_t> lam_sum(cards);
279 
280  if (m_theta_region[id_region].len_array == 0)
281  lam_sum.set_const(0);
282  else
283  lam_sum = m_theta_region[id_region].clone();
284 
285  int32_t num_intersections = m_region_intersections[id_region].size();
286  vector<SGNDArray<float64_t> > lam_minus; // substract message: \lambda_s^{-c}(x_s)
287  // \sum_{\hat{s}} \lambda_{\hat{s}}^{-c}(x_{\hat{s}}) + \theta_c(x_c)
288  int32_t s = 0;
289 
290  for (set<int32_t>::iterator t = m_region_intersections[id_region].begin();
291  t != m_region_intersections[id_region].end(); t++)
292  {
293  int32_t id_intersection = *t;
294  SGNDArray<float64_t> tmp = m_msgs_into_intersections[id_intersection].clone();
295  tmp -= m_msgs_from_region[id_region][s];
296 
297  lam_minus.push_back(tmp);
298 
299  if (vars.size() == (int32_t)m_region_inds_intersections[id_region][s].size())
300  lam_sum += tmp;
301  else
302  {
303  SGNDArray<float64_t> tmp_expand(lam_sum.get_dimensions());
304  tmp.expand(tmp_expand, m_region_inds_intersections[id_region][s]);
305  lam_sum += tmp_expand;
306  }
307 
308  // take out the old incoming message: \lambda_{c \to s}(x_s)
309  m_msgs_into_intersections[id_intersection] -= m_msgs_from_region[id_region][s];
310  s++;
311  }
312 
313  s = 0;
314 
315  for (set<int32_t>::iterator t = m_region_intersections[id_region].begin();
316  t != m_region_intersections[id_region].end(); t++)
317  {
318  // maximazation: \max_{x_c} \sum_{\hat{s}} \lambda_{\hat{s}}^{-c}(x_{\hat{s}}) + \theta_c(x_c)
319  SGNDArray<float64_t> lam_max(lam_minus[s].get_dimensions());
320  max_in_subdimension(lam_sum, m_region_inds_intersections[id_region][s], lam_max);
321  int32_t id_intersection = *t;
322  // weighted sum
323  lam_max *= 1.0/num_intersections;
324  m_msgs_from_region[id_region][s] = lam_max.clone();
325  m_msgs_from_region[id_region][s] -= lam_minus[s];
326  // put in new message
327  m_msgs_into_intersections[id_intersection] += m_msgs_from_region[id_region][s];
328  s++;
329  }
330 
331  SG_UNREF(factor);
332 }
333 
// max_in_subdimension(tar_arr, subset_inds, max_res) — its signature (source
// line 334) is not shown in this listing; see the cross-reference entry
// GEMPLP.cpp:334.
// Maximizes tar_arr over every dimension NOT listed in subset_inds and writes
// the result into max_res. In the non-shortcut path, max_res must already be
// allocated with the subset's dimensions (update_messages allocates it from the
// message dimensions). If the subset covers every dimension, max_res is
// replaced by a copy of tar_arr.
335 {
336  // If the subset equals the target array then maximizing would
337  // give us the target array (assuming there is no reordering)
// NOTE(review): this shortcut only compares sizes; a subset_inds that lists
// all dimensions in a permuted order would not be handled correctly.
338  if (subset_inds.size() == tar_arr.num_dims)
339  {
340  max_res = tar_arr.clone();
341  return;
342  }
343  else
344  max_res.set_const(-CMath::INFTY);
345 
346  // Go over all values of the target array. For each check if its
347  // value on the subset is larger than the current max
348  SGVector<int32_t> inds_for_tar(tar_arr.num_dims);
349  inds_for_tar.zero();
350 
351  for (int32_t vi = 0; vi < tar_arr.len_array; vi++)
352  {
// y is the flat index into max_res of the current entry's subset coordinates.
// For a 2-D subset it is computed with the first subset dimension varying
// slowest.
// NOTE(review): for any other subset size (0, or >= 3 but fewer than num_dims)
// y stays 0, so every entry is folded into max_res[0] without an error.
353  int32_t y = 0;
354 
355  if (subset_inds.size() == 1)
356  y = inds_for_tar[subset_inds[0]];
357  else if (subset_inds.size() == 2)
358  {
359  int32_t ind1 = subset_inds[0];
360  int32_t ind2 = subset_inds[1];
361  y = inds_for_tar[ind1] * max_res.dims[1] + inds_for_tar[ind2];
362  }
363  max_res[y] = max(max_res[y], tar_arr.array[vi]);
// advance the multi-dimensional index in step with the flat index vi
364  tar_arr.next_index(inds_for_tar);
365  }
366 }
float64_t evaluate_energy(const SGVector< int32_t > state) const
float64_t m_int_gap_thr
Definition: GEMPLP.h:65
virtual ~CGEMPLP()
Definition: GEMPLP.cpp:33
static const float64_t INFTY
infinity
Definition: Math.h:2048
SGNDArray< T > clone() const
Definition: SGNDArray.cpp:105
const SGVector< int32_t > get_variables() const
Definition: Factor.cpp:107
CDynamicObjectArray * get_factors() const
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
Class CMAPInferImpl: abstract base class for MAP inference implementations.
Definition: MAPInference.h:98
vector< SGVector< int32_t > > m_all_intersections
Definition: GEMPLP.h:140
SGVector< float64_t > get_energies() const
Definition: Factor.cpp:169
int32_t size() const
Definition: SGVector.h:115
CFactorGraph * m_fg
Definition: MAPInference.h:128
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
Parameter m_param
Definition: GEMPLP.h:136
index_t num_dims
Definition: SGNDArray.h:180
vector< SGNDArray< float64_t > > m_theta_region
Definition: GEMPLP.h:150
virtual float64_t inference(SGVector< int32_t > assignment)
Definition: GEMPLP.cpp:189
vector< vector< SGNDArray< float64_t > > > m_msgs_from_region
Definition: GEMPLP.h:148
void max_in_subdimension(SGNDArray< float64_t > tar_arr, SGVector< int32_t > &subset_inds, SGNDArray< float64_t > &max_res) const
Definition: GEMPLP.cpp:334
virtual const char * get_name() const
Definition: GEMPLP.h:83
virtual bool equals(CSGObject *other, float64_t accuracy=0.0, bool tolerant=false)
Definition: SGObject.cpp:618
SGNDArray< float64_t > convert_energy_to_potential(CFactor *factor)
Definition: GEMPLP.cpp:132
void set_const(T const_elem)
Definition: SGNDArray.cpp:146
CDynamicObjectArray * m_factors
Definition: GEMPLP.h:138
#define SG_UNREF(x)
Definition: SGObject.h:52
index_t len_array
Definition: SGNDArray.h:183
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
int32_t find_intersection_index(SGVector< int32_t > region_A, SGVector< int32_t > region_B)
Definition: GEMPLP.cpp:156
#define SG_SDEBUG(...)
Definition: SGIO.h:168
Class CFactorGraph: a factor graph is, in general, a structured input.
Definition: FactorGraph.h:27
index_t * dims
Definition: SGNDArray.h:177
CSGObject * get_element(int32_t index) const
vector< set< int32_t > > m_region_intersections
Definition: GEMPLP.h:142
Matrix::Scalar max(Matrix m)
Definition: Redux.h:66
vector< vector< SGVector< int32_t > > > m_region_inds_intersections
Definition: GEMPLP.h:144
SGVector< int32_t > get_cardinalities() const
vector< SGNDArray< float64_t > > m_msgs_into_intersections
Definition: GEMPLP.h:146
void next_index(SGVector< index_t > &curr_index) const
Definition: SGNDArray.cpp:286
SGVector< index_t > find(T elem)
Definition: SGVector.cpp:809
Class CFactor: a factor is defined on a clique in the factor graph. Each factor can have its own data...
Definition: Factor.h:89
const SGVector< int32_t > get_cardinalities() const
Definition: Factor.cpp:122
void expand(SGNDArray &big_array, SGVector< index_t > &axes)
Definition: SGNDArray.cpp:303

SHOGUN 机器学习工具包 - 项目文档