SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
MAPInference.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
16 
17 using namespace shogun;
18 
20 {
21  SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");
22 
23  init();
24 }
25 
27  : CSGObject()
28 {
29  init();
30  m_fg = fg;
31 
32  REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
33 
34  switch(inference_method)
35  {
36  case TREE_MAX_PROD:
37  m_infer_impl = new CTreeMaxProduct(fg);
38  break;
39  case GRAPH_CUT:
40  m_infer_impl = new CGraphCut(fg);
41  break;
42  case GEMPLP:
43  m_infer_impl = new CGEMPLP(fg);
44  break;
45  case LOOPY_MAX_PROD:
46  SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
47  get_name());
48  break;
49  case LP_RELAXATION:
50  SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
51  get_name());
52  break;
53  case TRWS_MAX_PROD:
54  SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
55  get_name());
56  break;
57  default:
58  SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
59  get_name());
60  break;
61  }
62 
64  SG_REF(m_fg);
65 }
66 
68 {
71  SG_UNREF(m_fg);
72 }
73 
74 void CMAPInference::init()
75 {
76  SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE);
77  SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE);
78  SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE);
79  SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE);
80 
81  m_outputs = NULL;
82  m_infer_impl = NULL;
83  m_fg = NULL;
84  m_energy = 0;
85 }
86 
88 {
89  SGVector<int32_t> assignment(m_fg->get_num_vars());
90  assignment.zero();
91  m_energy = m_infer_impl->inference(assignment);
92 
93  // create structured output, with default normalized hamming loss
95  SGVector<float64_t> loss_weights(m_fg->get_num_vars());
96  SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
97  m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor
99 }
100 
102 {
103  SG_REF(m_outputs);
104  return m_outputs;
105 }
106 
108 {
109  return m_energy;
110 }
111 
112 //-----------------------------------------------------------------
113 
115 {
116  register_parameters();
117 }
118 
120  : CSGObject()
121 {
122  register_parameters();
123  m_fg = fg;
124 }
125 
127 {
128 }
129 
130 void CMAPInferImpl::register_parameters()
131 {
132  SG_ADD((CSGObject**)&m_fg, "fg",
133  "Factor graph pointer", MS_NOT_AVAILABLE);
134 
135  m_fg = NULL;
136 }
137 
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:221
virtual float64_t inference(SGVector< int32_t > assignment)=0
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
#define SG_REF(x)
Definition: SGObject.h:54
float64_t get_energy() const
CFactorGraph * m_fg
Definition: MAPInference.h:128
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
CFactorGraph * m_fg
Definition: MAPInference.h:83
int32_t get_num_vars() const
double float64_t
Definition: common.h:50
Class CFactorGraphObservation is used as the structured output.
CFactorGraphObservation * m_outputs
Definition: MAPInference.h:86
virtual void inference()
#define SG_UNREF(x)
Definition: SGObject.h:55
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:27
virtual const char * get_name() const
Definition: MAPInference.h:63
CFactorGraphObservation * get_structured_outputs() const
#define SG_ADD(...)
Definition: SGObject.h:84
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:132
CMAPInferImpl * m_infer_impl
Definition: MAPInference.h:92

SHOGUN Machine Learning Toolbox - Documentation