SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
FactorGraph.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
13 
14 using namespace shogun;
15 
17  : CSGObject()
18 {
19  SG_UNSTABLE("CFactorGraph::CFactorGraph()", "\n");
20 
21  register_parameters();
22  init();
23 }
24 
26  : CSGObject()
27 {
28  m_cardinalities = card;
29  register_parameters();
30  init();
31 }
32 
34  : CSGObject()
35 {
36  register_parameters();
38  // No need to unref and ref in this case
39  m_factors = fg.get_factors();
41  m_dset = fg.get_disjoint_set();
42  m_has_cycle = !(fg.is_acyclic_graph());
44 }
45 
47 {
51 
52 #ifdef USE_REFERENCE_COUNTING
53  if (m_factors != NULL)
54  SG_DEBUG("CFactorGraph::~CFactorGraph(): m_factors->ref_count() = %d.\n", m_factors->ref_count());
55 
56  if (m_datasources != NULL)
57  SG_DEBUG("CFactorGraph::~CFactorGraph(): m_datasources->ref_count() = %d.\n", m_datasources->ref_count());
58 
59  SG_DEBUG("CFactorGraph::~CFactorGraph(): this->ref_count() = %d.\n", this->ref_count());
60 #endif
61 }
62 
63 void CFactorGraph::register_parameters()
64 {
65  SG_ADD(&m_cardinalities, "cardinalities", "Cardinalities", MS_NOT_AVAILABLE);
66  SG_ADD((CSGObject**)&m_factors, "factors", "Factors", MS_NOT_AVAILABLE);
67  SG_ADD((CSGObject**)&m_datasources, "datasources", "Factor data sources", MS_NOT_AVAILABLE);
68  SG_ADD((CSGObject**)&m_dset, "dset", "Disjoint set", MS_NOT_AVAILABLE);
69  SG_ADD(&m_has_cycle, "has_cycle", "Whether has circle in graph", MS_NOT_AVAILABLE);
70  SG_ADD(&m_num_edges, "num_edges", "Number of edges", MS_NOT_AVAILABLE);
71 }
72 
73 void CFactorGraph::init()
74 {
75  m_has_cycle = false;
76  m_num_edges = 0;
77  m_factors = NULL;
78  m_datasources = NULL;
81 
82 #ifdef USE_REFERENCE_COUNTING
83  if (m_factors != NULL)
84  SG_DEBUG("CFactorGraph::init(): m_factors->ref_count() = %d.\n", m_factors->ref_count());
85 #endif
86 
87  // NOTE m_cards cannot be empty
89 
92  SG_REF(m_dset);
93 }
94 
96 {
97  m_factors->push_back(factor);
98  m_num_edges += factor->get_variables().size();
99 
100  // graph structure changed after adding factors
101  if (m_dset->get_connected())
102  m_dset->set_connected(false);
103 }
104 
106 {
107  m_datasources->push_back(datasource);
108 }
109 
111 {
112  SG_REF(m_factors);
113  return m_factors;
114 }
115 
117 {
119  return m_datasources;
120 }
121 
123 {
124  return m_factors->get_num_elements();
125 }
126 
128 {
129  return m_cardinalities;
130 }
131 
133 {
134  m_cardinalities = cards.clone();
135 }
136 
138 {
139  SG_REF(m_dset);
140  return m_dset;
141 }
142 
144 {
145  return m_num_edges;
146 }
147 
149 {
150  return m_cardinalities.size();
151 }
152 
154 {
155  for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
156  {
157  CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
158  fac->compute_energies();
159  SG_UNREF(fac);
160  }
161 }
162 
164 {
165  ASSERT(state.size() == m_cardinalities.size());
166 
167  float64_t energy = 0.0;
168  for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
169  {
170  CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
171  energy += fac->evaluate_energy(state);
172  SG_UNREF(fac);
173  }
174  return energy;
175 }
176 
178 {
179  return evaluate_energy(obs->get_data());
180 }
181 
183 {
184  int num_assig = 1;
185  SGVector<int32_t> cumprod_cards(m_cardinalities.size());
186  for (int32_t n = 0; n < m_cardinalities.size(); ++n)
187  {
188  cumprod_cards[n] = num_assig;
189  num_assig *= m_cardinalities[n];
190  }
191 
192  SGVector<float64_t> etable(num_assig);
193  for (int32_t ei = 0; ei < num_assig; ++ei)
194  {
196  for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi)
197  assig[vi] = (ei / cumprod_cards[vi]) % m_cardinalities[vi];
198 
199  etable[ei] = evaluate_energy(assig);
200 
201  for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi)
202  SG_SPRINT("%d ", assig[vi]);
203 
204  SG_SPRINT("| %f\n", etable[ei]);
205  }
206 
207  return etable;
208 }
209 
211 {
212  if (m_dset->get_connected())
213  return;
214 
215  // need to be reset once factor graph is updated
216  m_dset->make_sets();
217  bool flag = false;
218 
219  for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
220  {
221  CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
222  SGVector<int32_t> vars = fac->get_variables();
223 
224  int32_t r0 = m_dset->find_set(vars[0]);
225  for (int32_t vi = 1; vi < vars.size(); vi++)
226  {
227  // for two nodes in a factor, should be an edge between them
228  // but this time link() isn't performed, if they are linked already
229  // means there is another path connected them, so cycle detected
230  int32_t ri = m_dset->find_set(vars[vi]);
231 
232  if (r0 == ri)
233  {
234  flag = true;
235  continue;
236  }
237 
238  r0 = m_dset->link_set(r0, ri);
239  }
240 
241  SG_UNREF(fac);
242  }
243  m_has_cycle = flag;
244  m_dset->set_connected(true);
245 }
246 
248 {
249  return !m_has_cycle;
250 }
251 
253 {
254  return (m_dset->get_num_sets() == 1);
255 }
256 
258 {
259  return (m_has_cycle == false && m_dset->get_num_sets() == 1);
260 }
261 
263 {
265 }
266 
268 {
269  if (loss.size() == 0)
270  {
271  loss.resize_vector(states_gt.size());
272  SGVector<float64_t>::fill_vector(loss.vector, loss.vlen, 1.0 / states_gt.size());
273  }
274 
275  int32_t num_vars = states_gt.size();
276  ASSERT(num_vars == loss.size());
277 
278  SGVector<int32_t> var_flags(num_vars);
279  var_flags.zero();
280 
281  // augment loss to incorrect states in the first factor containing the variable
282  // since += L_i for each variable if it takes wrong state ever
283  // TODO: augment unary factors
284  for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
285  {
286  CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
287  SGVector<int32_t> vars = fac->get_variables();
288  for (int32_t vi = 0; vi < vars.size(); vi++)
289  {
290  int32_t vv = vars[vi];
291  ASSERT(vv < num_vars);
292  if (var_flags[vv])
293  continue;
294 
295  SGVector<float64_t> energies = fac->get_energies();
296  for (int32_t ei = 0; ei < energies.size(); ei++)
297  {
298  CTableFactorType* ftype = fac->get_factor_type();
299  int32_t vstate = ftype->state_from_index(ei, vi);
300  SG_UNREF(ftype);
301 
302  if (states_gt[vv] == vstate)
303  continue;
304 
305  // -delta(y_n, y_i_n)
306  fac->set_energy(ei, energies[ei] - loss[vv]);
307  }
308 
309  var_flags[vv] = 1;
310  }
311 
312  SG_UNREF(fac);
313  }
314 
315  // make sure all variables have been checked
316  int32_t min_var = CMath::min(var_flags.vector, var_flags.vlen);
317  ASSERT(min_var == 1);
318 }
319 
void set_energy(int32_t ei, float64_t value)
Definition: Factor.cpp:193
float64_t evaluate_energy(const SGVector< int32_t > state) const
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:223
SGVector< int32_t > get_data() const
void add_factor(CFactor *factor)
Definition: FactorGraph.cpp:95
void add_data_source(CFactorDataSource *datasource)
CTableFactorType * get_factor_type() const
Definition: Factor.cpp:95
int32_t get_num_edges() const
bool is_tree_graph() const
const SGVector< int32_t > get_variables() const
Definition: Factor.cpp:107
SGVector< float64_t > get_loss_weights() const
int32_t get_num_factors() const
void compute_energies()
Definition: Factor.cpp:210
CDynamicObjectArray * get_factors() const
float64_t evaluate_energy(const SGVector< int32_t > state) const
Definition: Factor.cpp:204
CDisjointSet * get_disjoint_set() const
#define SG_REF(x)
Definition: SGObject.h:51
SGVector< float64_t > get_energies() const
Definition: Factor.cpp:169
CDynamicObjectArray * m_datasources
Definition: FactorGraph.h:154
void set_cardinalities(SGVector< int32_t > cards)
int32_t size() const
Definition: SGVector.h:115
index_t vlen
Definition: SGVector.h:494
#define SG_SPRINT(...)
Definition: SGIO.h:180
#define ASSERT(x)
Definition: SGIO.h:201
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
CDynamicObjectArray * get_factor_data_sources() const
int32_t get_num_vars() const
double float64_t
Definition: common.h:50
Class CFactorGraphObservation is used as the structured output.
Class CFactorDataSource Source for factor data. In some cases, the same data can be shared by many fa...
Definition: Factor.h:27
CDisjointSet * m_dset
Definition: FactorGraph.h:157
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
Class CDisjointSet data structure for linking graph nodes It's easy to identify connected graph...
Definition: DisjointSet.h:26
SGVector< float64_t > evaluate_energies() const
int32_t state_from_index(int32_t ei, int32_t var_index) const
Definition: FactorType.cpp:155
#define SG_UNREF(x)
Definition: SGObject.h:52
#define SG_DEBUG(...)
Definition: SGIO.h:107
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
CDynamicObjectArray * m_factors
Definition: FactorGraph.h:151
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:27
static T min(T a, T b)
Definition: Math.h:157
Class CTableFactorType the way that store assignments of variables and energies in a table or a multi...
Definition: FactorType.h:122
bool is_acyclic_graph() const
SGVector< T > clone() const
Definition: SGVector.cpp:209
CSGObject * get_element(int32_t index) const
void resize_vector(int32_t n)
Definition: SGVector.cpp:259
SGVector< int32_t > get_cardinalities() const
virtual void loss_augmentation(CFactorGraphObservation *gt)
#define SG_ADD(...)
Definition: SGObject.h:81
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:132
SGVector< int32_t > m_cardinalities
Definition: FactorGraph.h:148
Class CFactor A factor is defined on a clique in the factor graph. Each factor can have its own data...
Definition: Factor.h:89
bool is_connected_graph() const

SHOGUN 机器学习工具包 - 项目文档