SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MAPInference.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
15 
16 using namespace shogun;
17 
19 {
20  SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");
21 
22  init();
23 }
24 
26  : CSGObject()
27 {
28  init();
29  m_fg = fg;
30 
31  REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
32 
33  switch(inference_method)
34  {
35  case TREE_MAX_PROD:
36  m_infer_impl = new CTreeMaxProduct(fg);
37  break;
38  case GRAPH_CUT:
39  m_infer_impl = new CGraphCut(fg);
40  break;
41  case LOOPY_MAX_PROD:
42  SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
43  get_name());
44  break;
45  case LP_RELAXATION:
46  SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
47  get_name());
48  break;
49  case TRWS_MAX_PROD:
50  SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
51  get_name());
52  break;
53  default:
54  SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
55  get_name());
56  break;
57  }
58 
60  SG_REF(m_fg);
61 }
62 
64 {
67  SG_UNREF(m_fg);
68 }
69 
70 void CMAPInference::init()
71 {
72  SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE);
73  SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE);
74  SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE);
75  SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE);
76 
77  m_outputs = NULL;
78  m_infer_impl = NULL;
79  m_fg = NULL;
80  m_energy = 0;
81 }
82 
84 {
85  SGVector<int32_t> assignment(m_fg->get_num_vars());
86  assignment.zero();
87  m_energy = m_infer_impl->inference(assignment);
88 
89  // create structured output, with default normalized hamming loss
91  SGVector<float64_t> loss_weights(m_fg->get_num_vars());
92  SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
93  m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor
95 }
96 
98 {
100  return m_outputs;
101 }
102 
104 {
105  return m_energy;
106 }
107 
108 //-----------------------------------------------------------------
109 
111 {
112  register_parameters();
113 }
114 
116  : CSGObject()
117 {
118  register_parameters();
119  m_fg = fg;
120 }
121 
123 {
124 }
125 
126 void CMAPInferImpl::register_parameters()
127 {
128  SG_ADD((CSGObject**)&m_fg, "fg",
129  "Factor graph pointer", MS_NOT_AVAILABLE);
130 
131  m_fg = NULL;
132 }
133 

SHOGUN Machine Learning Toolbox - Documentation