SHOGUN  6.1.3
PermutationMMD.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2012 - 2013 Heiko Strathmann
4  * Written (w) 2014 - 2017 Soumyajit De
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * The views and conclusions contained in the software and documentation are those
28  * of the authors and should not be interpreted as representing official policies,
29  * either expressed or implied, of the Shogun Development Team.
30  */
31 
32 #ifndef PERMUTATION_MMD_H_
33 #define PERMUTATION_MMD_H_
34 
35 #include <algorithm>
36 #include <numeric>
37 #include <shogun/lib/SGVector.h>
38 #include <shogun/lib/SGMatrix.h>
41 
42 namespace shogun
43 {
44 
45 namespace internal
46 {
47 
48 namespace mmd
49 {
50 #ifndef DOXYGEN_SHOULD_SKIP_THIS
51 struct PermutationMMD : ComputeMMD
52 {
53  PermutationMMD() : m_save_inds(false)
54  {
55  }
56 
57  template <class Kernel>
58  SGVector<float32_t> operator()(const Kernel& kernel)
59  {
60  ASSERT(m_n_x>0 && m_n_y>0);
61  ASSERT(m_num_null_samples>0);
62  precompute_permutation_inds();
63 
64  const index_t size=m_n_x+m_n_y;
65  SGVector<float32_t> null_samples(m_num_null_samples);
66 #pragma omp parallel for
67  for (auto n=0; n<m_num_null_samples; ++n)
68  {
69  terms_t terms;
70  for (auto j=0; j<size; ++j)
71  {
72  auto inverted_col=m_inverted_permuted_inds(j, n);
73  for (auto i=j; i<size; ++i)
74  {
75  auto inverted_row=m_inverted_permuted_inds(i, n);
76 
77  if (inverted_row>=inverted_col)
78  add_term_lower(terms, kernel(i, j), inverted_row, inverted_col);
79  else
80  add_term_lower(terms, kernel(i, j), inverted_col, inverted_row);
81  }
82  }
83  null_samples[n]=compute(terms);
84  SG_SDEBUG("null_samples[%d] = %f!\n", n, null_samples[n]);
85  }
86  return null_samples;
87  }
88 
89  SGMatrix<float32_t> operator()(const KernelManager& kernel_mgr)
90  {
91  ASSERT(m_n_x>0 && m_n_y>0);
92  ASSERT(m_num_null_samples>0);
93  precompute_permutation_inds();
94 
95  const index_t size=m_n_x+m_n_y;
96  SGMatrix<float32_t> null_samples(m_num_null_samples, kernel_mgr.num_kernels());
97  SGVector<float32_t> km(size*(size+1)/2);
98  for (auto k=0; k<kernel_mgr.num_kernels(); ++k)
99  {
100  auto kernel=kernel_mgr.kernel_at(k);
101  terms_t terms;
102  for (auto i=0; i<size; ++i)
103  {
104  for (auto j=i; j<size; ++j)
105  {
106  auto index=i*size-i*(i+1)/2+j;
107  km[index]=kernel->kernel(i, j);
108  }
109  }
110 
111 #pragma omp parallel for
112  for (auto n=0; n<m_num_null_samples; ++n)
113  {
114  terms_t null_terms;
115  for (auto i=0; i<size; ++i)
116  {
117  auto inverted_row=m_inverted_permuted_inds(i, n);
118  auto index_base=i*size-i*(i+1)/2;
119  for (auto j=i; j<size; ++j)
120  {
121  auto index=index_base+j;
122  auto inverted_col=m_inverted_permuted_inds(j, n);
123 
124  if (inverted_row<=inverted_col)
125  add_term_upper(null_terms, km[index], inverted_row, inverted_col);
126  else
127  add_term_upper(null_terms, km[index], inverted_col, inverted_row);
128  }
129  }
130  null_samples(n, k)=compute(null_terms);
131  }
132  }
133  return null_samples;
134  }
135 
136  template <class Kernel>
137  float64_t p_value(const Kernel& kernel)
138  {
139  auto statistic=ComputeMMD::operator()(kernel);
140  auto null_samples=operator()(kernel);
141  return compute_p_value(null_samples, statistic);
142  }
143 
144  SGVector<float64_t> p_value(const KernelManager& kernel_mgr)
145  {
146  ASSERT(m_n_x>0 && m_n_y>0);
147  ASSERT(m_num_null_samples>0);
148  precompute_permutation_inds();
149 
150  const index_t size=m_n_x+m_n_y;
151  SGVector<float32_t> null_samples(m_num_null_samples);
152  SGVector<float64_t> result(kernel_mgr.num_kernels());
153 
154  SGVector<float32_t> km(size*(size+1)/2);
155  for (auto k=0; k<kernel_mgr.num_kernels(); ++k)
156  {
157  auto kernel=kernel_mgr.kernel_at(k);
158  terms_t terms;
159  for (auto i=0; i<size; ++i)
160  {
161  for (auto j=i; j<size; ++j)
162  {
163  auto index=i*size-i*(i+1)/2+j;
164  km[index]=kernel->kernel(i, j);
165  add_term_upper(terms, km[index], i, j);
166  }
167  }
168  float32_t statistic=compute(terms);
169  SG_SDEBUG("Kernel(%d): statistic=%f\n", k, statistic);
170 
171 #pragma omp parallel for
172  for (auto n=0; n<m_num_null_samples; ++n)
173  {
174  terms_t null_terms;
175  for (auto i=0; i<size; ++i)
176  {
177  auto inverted_row=m_inverted_permuted_inds(i, n);
178  auto index_base=i*size-i*(i+1)/2;
179  for (auto j=i; j<size; ++j)
180  {
181  auto index=index_base+j;
182  auto inverted_col=m_inverted_permuted_inds(j, n);
183 
184  if (inverted_row<=inverted_col)
185  add_term_upper(null_terms, km[index], inverted_row, inverted_col);
186  else
187  add_term_upper(null_terms, km[index], inverted_col, inverted_row);
188  }
189  }
190  null_samples[n]=compute(null_terms);
191  }
192  result[k]=compute_p_value(null_samples, statistic);
193  SG_SDEBUG("Kernel(%d): p_value=%f\n", k, result[k]);
194  }
195 
196  return result;
197  }
198 
199  inline void precompute_permutation_inds()
200  {
201  ASSERT(m_num_null_samples>0);
202  allocate_permutation_inds();
203  for (auto n=0; n<m_num_null_samples; ++n)
204  {
205  std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0);
206  CMath::permute(m_permuted_inds);
207  if (m_save_inds)
208  {
209  auto offset=n*m_permuted_inds.size();
210  std::copy(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), &m_all_inds.matrix[offset]);
211  }
212  for (index_t i=0; i<m_permuted_inds.size(); ++i)
213  m_inverted_permuted_inds(m_permuted_inds[i], n)=i;
214  }
215  }
216 
217  inline float64_t compute_p_value(SGVector<float32_t>& null_samples, float32_t statistic) const
218  {
219  std::sort(null_samples.data(), null_samples.data()+null_samples.size());
220  float64_t idx=null_samples.find_position_to_insert(statistic);
221  return 1.0-idx/null_samples.size();
222  }
223 
224  inline void allocate_permutation_inds()
225  {
226  const index_t size=m_n_x+m_n_y;
227  if (m_permuted_inds.size()!=size)
228  m_permuted_inds=SGVector<index_t>(size);
229 
230  if (m_inverted_permuted_inds.num_cols!=m_num_null_samples || m_inverted_permuted_inds.num_rows!=size)
231  m_inverted_permuted_inds=SGMatrix<index_t>(size, m_num_null_samples);
232 
233  if (m_save_inds && (m_all_inds.num_cols!=m_num_null_samples || m_all_inds.num_rows!=size))
234  m_all_inds=SGMatrix<index_t>(size, m_num_null_samples);
235  }
236 
237  index_t m_num_null_samples;
238  bool m_save_inds;
239  SGVector<index_t> m_permuted_inds;
240  SGMatrix<index_t> m_inverted_permuted_inds;
241  SGMatrix<index_t> m_all_inds;
242 };
243 #endif // DOXYGEN_SHOULD_SKIP_THIS
244 }
245 
246 }
247 
248 }
249 
250 #endif // PERMUTATION_MMD_H_
static void permute(SGVector< T > v, CRandom *rand=NULL)
Definition: Math.h:962
int32_t index_t
Definition: common.h:72
#define ASSERT(x)
Definition: SGIO.h:176
double float64_t
Definition: common.h:60
float float32_t
Definition: common.h:59
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_SDEBUG(...)
Definition: SGIO.h:153

SHOGUN Machine Learning Toolbox - Documentation