SHOGUN  6.1.3
ComputeMMD.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2012 - 2013 Heiko Strathmann
4  * Written (w) 2014 - 2017 Soumyajit De
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * The views and conclusions contained in the software and documentation are those
28  * of the authors and should not be interpreted as representing official policies,
29  * either expressed or implied, of the Shogun Development Team.
30  */
31 
32 #ifndef COMPUTE_MMD_H_
33 #define COMPUTE_MMD_H_
34 
35 #include <array>
36 #include <vector>
37 #include <shogun/lib/config.h>
38 #include <shogun/lib/SGVector.h>
39 #include <shogun/lib/SGMatrix.h>
40 #include <shogun/kernel/Kernel.h>
45 #include <shogun/io/SGIO.h>
46 
47 namespace shogun
48 {
49 
50 namespace internal
51 {
52 
53 namespace mmd
54 {
55 
56 struct terms_t
57 {
58  std::array<float64_t, 3> term{};
59  std::array<float64_t, 3> diag{};
60 };
61 #ifndef DOXYGEN_SHOULD_SKIP_THIS
62 
65 struct ComputeMMD
66 {
67  ComputeMMD() : m_n_x(0), m_n_y(0), m_stype(ST_UNBIASED_FULL)
68  {
69  }
70 
71  template <class Kernel>
72  float32_t operator()(const Kernel& kernel) const
73  {
74  ASSERT(m_n_x>0 && m_n_y>0);
75  const index_t size=m_n_x+m_n_y;
76  terms_t terms;
77  for (auto i=0; i<size; ++i)
78  {
79  for (auto j=i; j<size; ++j)
80  add_term_upper(terms, kernel(i, j), i, j);
81  }
82  return compute(terms);
83  }
84 
85  template <typename T>
86  float32_t operator()(const SGMatrix<T>& kernel_matrix) const
87  {
88  ASSERT(m_n_x>0 && m_n_y>0);
89  const index_t size=m_n_x+m_n_y;
90  ASSERT(kernel_matrix.num_rows==size && kernel_matrix.num_cols==size);
91 
93  typedef Eigen::Block<Eigen::Map<const MatrixXt> > BlockXt;
94 
95  Eigen::Map<const MatrixXt> map(kernel_matrix.matrix, kernel_matrix.num_rows, kernel_matrix.num_cols);
96 
97  const BlockXt& b_x=map.block(0, 0, m_n_x, m_n_x);
98  const BlockXt& b_y=map.block(m_n_x, m_n_x, m_n_y, m_n_y);
99  const BlockXt& b_xy=map.block(m_n_x, 0, m_n_y, m_n_x);
100 
101  terms_t terms;
102  terms.diag[0]=b_x.diagonal().sum();
103  terms.diag[1]=b_y.diagonal().sum();
104  terms.diag[2]=b_xy.diagonal().sum();
105 
106  terms.term[0]=(b_x.sum()-terms.diag[0])/2+terms.diag[0];
107  terms.term[1]=(b_y.sum()-terms.diag[1])/2+terms.diag[1];
108  terms.term[2]=b_xy.sum();
109 
110  return compute(terms);
111  }
112 
113  SGVector<float64_t> operator()(const KernelManager& kernel_mgr) const
114  {
115  ASSERT(m_n_x>0 && m_n_y>0);
116  std::vector<terms_t> terms(kernel_mgr.num_kernels());
117  const index_t size=m_n_x+m_n_y;
118  for (auto j=0; j<size; ++j)
119  {
120  for (auto i=j; i<size; ++i)
121  {
122  for (auto k=0; k<kernel_mgr.num_kernels(); ++k)
123  {
124  auto kernel=kernel_mgr.kernel_at(k)->kernel(i, j);
125  add_term_lower(terms[k], kernel, i, j);
126  }
127  }
128  }
129 
130  SGVector<float64_t> result(kernel_mgr.num_kernels());
131  for (auto k=0; k<kernel_mgr.num_kernels(); ++k)
132  {
133  result[k]=compute(terms[k]);
134  SG_SDEBUG("result[%d] = %f!\n", k, result[k]);
135  }
136  terms.resize(0);
137  return result;
138  }
139 
149  template <typename T>
150  inline void add_term_lower(terms_t& terms, T kernel_value, index_t i, index_t j) const
151  {
152  ASSERT(m_n_x>0 && m_n_y>0);
153  if (i<m_n_x && j<m_n_x && i>=j)
154  {
155  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_0!\n", i, j, kernel_value);
156  terms.term[0]+=kernel_value;
157  if (i==j)
158  terms.diag[0]+=kernel_value;
159  }
160  else if (i>=m_n_x && j>=m_n_x && i>=j)
161  {
162  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_1!\n", i, j, kernel_value);
163  terms.term[1]+=kernel_value;
164  if (i==j)
165  terms.diag[1]+=kernel_value;
166  }
167  else if (i>=m_n_x && j<m_n_x)
168  {
169  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_2!\n", i, j, kernel_value);
170  terms.term[2]+=kernel_value;
171  if (i-m_n_x==j)
172  terms.diag[2]+=kernel_value;
173  }
174  }
175 
185  template <typename T>
186  inline void add_term_upper(terms_t& terms, T kernel_value, index_t i, index_t j) const
187  {
188  ASSERT(m_n_x>0 && m_n_y>0);
189  if (i<m_n_x && j<m_n_x && i<=j)
190  {
191  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_0!\n", i, j, kernel_value);
192  terms.term[0]+=kernel_value;
193  if (i==j)
194  terms.diag[0]+=kernel_value;
195  }
196  else if (i>=m_n_x && j>=m_n_x && i<=j)
197  {
198  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_1!\n", i, j, kernel_value);
199  terms.term[1]+=kernel_value;
200  if (i==j)
201  terms.diag[1]+=kernel_value;
202  }
203  else if (i<m_n_x && j>=m_n_x)
204  {
205  SG_SDEBUG("Adding Kernel(%d, %d)=%f to term_2!\n", i, j, kernel_value);
206  terms.term[2]+=kernel_value;
207  if (i+m_n_x==j)
208  terms.diag[2]+=kernel_value;
209  }
210  }
211 
212  inline float64_t compute(terms_t& terms) const
213  {
214  ASSERT(m_n_x>0 && m_n_y>0);
215  terms.term[0]=2*(terms.term[0]-terms.diag[0]);
216  terms.term[1]=2*(terms.term[1]-terms.diag[1]);
217  SG_SDEBUG("term_0 sum (without diagonal) = %f!\n", terms.term[0]);
218  SG_SDEBUG("term_1 sum (without diagonal) = %f!\n", terms.term[1]);
219  if (m_stype!=ST_BIASED_FULL)
220  {
221  terms.term[0]/=m_n_x*(m_n_x-1);
222  terms.term[1]/=m_n_y*(m_n_y-1);
223  }
224  else
225  {
226  terms.term[0]+=terms.diag[0];
227  terms.term[1]+=terms.diag[1];
228  SG_SDEBUG("term_0 sum (with diagonal) = %f!\n", terms.term[0]);
229  SG_SDEBUG("term_1 sum (with diagonal) = %f!\n", terms.term[1]);
230  terms.term[0]/=m_n_x*m_n_x;
231  terms.term[1]/=m_n_y*m_n_y;
232  }
233  SG_SDEBUG("term_0 (normalized) = %f!\n", terms.term[0]);
234  SG_SDEBUG("term_1 (normalized) = %f!\n", terms.term[1]);
235 
236  SG_SDEBUG("term_2 sum (with diagonal) = %f!\n", terms.term[2]);
237  if (m_stype==ST_UNBIASED_INCOMPLETE)
238  {
239  terms.term[2]-=terms.diag[2];
240  SG_SDEBUG("term_2 sum (without diagonal) = %f!\n", terms.term[2]);
241  terms.term[2]/=m_n_x*(m_n_x-1);
242  }
243  else
244  terms.term[2]/=m_n_x*m_n_y;
245  SG_SDEBUG("term_2 (normalized) = %f!\n", terms.term[2]);
246 
247  auto result=terms.term[0]+terms.term[1]-2*terms.term[2];
248  SG_SDEBUG("result = %f!\n", result);
249  return result;
250  }
251 
252  index_t m_n_x;
253  index_t m_n_y;
254  EStatisticType m_stype;
255 };
256 #endif // DOXYGEN_SHOULD_SKIP_THIS
257 }
258 
259 }
260 
261 }
262 #endif // COMPUTE_MMD_H_
int32_t index_t
Definition: common.h:72
#define ASSERT(x)
Definition: SGIO.h:176
double float64_t
Definition: common.h:60
std::array< float64_t, 3 > term
Definition: ComputeMMD.h:58
index_t num_rows
Definition: SGMatrix.h:495
EStatisticType
Definition: TestEnums.h:40
index_t num_cols
Definition: SGMatrix.h:497
float float32_t
Definition: common.h:59
shogun matrix
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_SDEBUG(...)
Definition: SGIO.h:153
std::array< float64_t, 3 > diag
Definition: ComputeMMD.h:59

SHOGUN Machine Learning Toolbox - Documentation