SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MAPInference.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
14 
15 using namespace shogun;
16 
// Body of the no-argument CMAPInference constructor. Its signature line
// (source line 17) is not present in this listing.
// It calls SG_UNSTABLE (presumably a warning that this constructor is not
// considered stable), then init(), which leaves m_fg, m_infer_impl and
// m_outputs NULL and m_energy 0.
// NOTE(review): an object built this way has no factor graph and no inference
// implementation attached; presumably this constructor exists only for
// serialization -- confirm.
18 {
19  SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");
20 
21  init();
22 }
23 
// Constructor body for CMAPInference taking a factor graph `fg` and an
// `inference_method` selector. The signature line (source line 24) is not
// present in this listing.
// Steps: reset members via init(), store fg, reject a NULL fg via REQUIRE,
// create the inference implementation for the requested method, and finally
// SG_REF m_fg.
// Only TREE_MAX_PROD is implemented (CTreeMaxProduct). Every other enumerator,
// and any unknown value, reports SG_ERROR.
25  : CSGObject()
26 {
27  init();
28  m_fg = fg;
29 
    // NOTE(review): m_fg is assigned before this check. That is harmless only
    // if REQUIRE aborts construction on failure (not visible here) -- confirm.
30  REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
31 
    // Select the MAP solver. The `break` after each SG_ERROR matters only if
    // SG_ERROR returns; presumably it throws/aborts -- confirm. If it returns,
    // m_infer_impl stays NULL (from init()).
32  switch(inference_method)
33  {
34  case TREE_MAX_PROD:
35  m_infer_impl = new CTreeMaxProduct(fg);
36  break;
37  case LOOPY_MAX_PROD:
38  SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
39  get_name());
40  break;
41  case LP_RELAXATION:
42  SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
43  get_name());
44  break;
45  case TRWS_MAX_PROD:
46  SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
47  get_name());
48  break;
49  case ITER_COND_MODE:
50  SG_ERROR("%s::CMAPInference(): ICM has not been implemented!\n",
51  get_name());
52  break;
53  case NAIVE_MEAN_FIELD:
54  SG_ERROR("%s::CMAPInference(): NaiveMeanField has not been implemented!\n",
55  get_name());
56  break;
57  case STRUCT_MEAN_FIELD:
58  SG_ERROR("%s::CMAPInference(): StructMeanField has not been implemented!\n",
59  get_name());
60  break;
61  default:
62  SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
63  get_name());
64  break;
65  }
66 
    // NOTE(review): source line 67 is absent from this listing. m_infer_impl
    // is allocated with `new` above, but no SG_REF for it is visible here --
    // confirm its ownership handling against the full source.
    // The reference taken on m_fg below is the one the destructor releases
    // with SG_UNREF(m_fg).
68  SG_REF(m_fg);
69 }
70 
// Cleanup body, presumably ~CMAPInference. Its signature (source line 71) and
// source lines 73-74 are absent from this listing.
// The visible part drops the reference on m_fg that the fg constructor took
// with SG_REF(m_fg).
// NOTE(review): two other members also need releasing:
//   - m_infer_impl, allocated with `new` in the constructor;
//   - m_outputs, SG_REF'd in inference().
// Presumably the missing lines 73-74 release them -- confirm, otherwise both
// objects leak.
72 {
75  SG_UNREF(m_fg);
76 }
77 
78 void CMAPInference::init()
79 {
80  SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE);
81  SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE);
82  SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE);
83  SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE);
84 
85  m_outputs = NULL;
86  m_infer_impl = NULL;
87  m_fg = NULL;
88  m_energy = 0;
89 }
90 
// Body of the method that runs MAP inference. Its signature (source line 91)
// and source line 98 are absent from this listing.
// It does three things:
//   1. asks m_infer_impl for an assignment of all m_fg->get_num_vars()
//      variables;
//   2. stores the returned energy in m_energy;
//   3. wraps the assignment in a new CFactorGraphObservation kept in
//      m_outputs.
// Precondition: m_fg and m_infer_impl must be non-NULL. init() sets both to
// NULL, so this must not run on an object built only by the no-argument
// constructor.
92 {
93  SGVector<int32_t> assignment(m_fg->get_num_vars());
    // Start from all zeros. The solver presumably writes the MAP labels into
    // `assignment` -- confirm against the inference() implementation.
94  assignment.zero();
95  m_energy = m_infer_impl->inference(assignment);
96 
97  // create structured output, with default normalized hamming loss
    // Every variable gets weight 1/num_vars, so the weights sum to 1. This is
    // the "normalized" part of the loss named above.
99  SGVector<float64_t> loss_weights(m_fg->get_num_vars());
100  SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
    // NOTE(review): m_outputs is overwritten with no visible SG_UNREF of the
    // previous object. Unless the missing source line 98 releases it, calling
    // this method more than once leaks.
    // Also reconcile the "already ref()" remark with the explicit SG_REF
    // below, to rule out a double reference.
101  m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor
102  SG_REF(m_outputs);
103 }
104 
// Accessor body whose signature (source line 105) is absent from this listing.
// It returns m_outputs after SG_REF. The caller therefore receives its own
// reference and presumably must SG_UNREF it when done.
// m_outputs is NULL until inference() has stored a result (init() sets it to
// NULL).
106 {
107  SG_REF(m_outputs);
108  return m_outputs;
109 }
110 
// Accessor body whose signature (source line 111) is absent from this listing.
// It returns m_energy, the value stored by the last inference() call from
// m_infer_impl->inference(). It is 0 before any inference, as set by init().
112 {
113  return m_energy;
114 }
115 
116 //-----------------------------------------------------------------
117 
// Constructor body whose signature (source line 118) is absent from this
// listing; presumably CMAPInferImpl's no-argument constructor.
// It only calls register_parameters(), which registers m_fg via SG_ADD and
// sets it to NULL.
119 {
120  register_parameters();
121 }
122 
// Constructor body taking a factor graph `fg`. The signature (source line 123)
// is absent from this listing; presumably it is CMAPInferImpl(CFactorGraph*).
// NOTE(review): unlike CMAPInference's constructor, fg is neither
// NULL-checked nor SG_REF'd here. m_fg therefore appears to be a non-owning
// pointer, and the graph must outlive this object.
124  : CSGObject()
125 {
    // register_parameters() resets m_fg to NULL, so it must run before the
    // assignment below.
126  register_parameters();
127  m_fg = fg;
128 }
129 
// Destructor body whose signature (source line 130) is absent from this
// listing; presumably ~CMAPInferImpl.
// It is empty, which is consistent with the visible constructors never
// SG_REF'ing m_fg: there is no reference to release here.
131 {
132 }
133 
134 void CMAPInferImpl::register_parameters()
135 {
136  SG_ADD((CSGObject**)&m_fg, "fg",
137  "Factor graph pointer", MS_NOT_AVAILABLE);
138 
139  m_fg = NULL;
140 }
141 

SHOGUN Machine Learning Toolbox - Documentation