SHOGUN  6.1.3
CrossValidationMMD.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2016 - 2017 Soumyajit De
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
31 #ifndef CROSS_VALIDATION_MMD_H_
32 #define CROSS_VALIDATION_MMD_H_
33 
34 #include <memory>
35 #include <algorithm>
36 #include <numeric>
37 #include <shogun/lib/SGMatrix.h>
38 #include <shogun/lib/SGVector.h>
43 
44 using std::unique_ptr;
45 
46 namespace shogun
47 {
48 
49 namespace internal
50 {
51 
52 namespace mmd
53 {
54 #ifndef DOXYGEN_SHOULD_SKIP_THIS
55 struct CrossValidationMMD : PermutationMMD
56 {
57  CrossValidationMMD(index_t n_x, index_t n_y, index_t num_folds, index_t num_null_samples)
58  {
59  ASSERT(n_x>0 && n_y>0);
60  ASSERT(num_folds>0);
61  ASSERT(num_null_samples>0);
62 
63  m_n_x=n_x;
64  m_n_y=n_y;
65  m_num_folds=num_folds;
66  m_num_null_samples=num_null_samples;
67  m_num_runs=DEFAULT_NUM_RUNS;
68  m_alpha=DEFAULT_ALPHA;
69 
70  init();
71  }
72 
73  void operator()(const KernelManager& kernel_mgr)
74  {
75  REQUIRE(m_rejections.num_rows==m_num_runs*m_num_folds,
76  "Number of rows in the measure matrix (was %d), has to be >= %d*%d = %d!\n",
77  m_rejections.num_rows, m_num_runs, m_num_folds, m_num_runs*m_num_folds);
78  REQUIRE(m_rejections.num_cols==kernel_mgr.num_kernels(),
79  "Number of columns in the measure matrix (was %d), has to equal to the nunber of kernels (%d)!\n",
80  m_rejections.num_cols, kernel_mgr.num_kernels());
81 
82  const index_t size=m_n_x+m_n_y;
83  const index_t orig_n_x=m_n_x;
84  const index_t orig_n_y=m_n_y;
85  SGVector<float64_t> null_samples(m_num_null_samples);
86  SGVector<float32_t> precomputed_km(size*(size+1)/2);
87 
88  for (auto k=0; k<kernel_mgr.num_kernels(); ++k)
89  {
90  auto kernel=kernel_mgr.kernel_at(k);
91  for (auto i=0; i<size; ++i)
92  {
93  for (auto j=i; j<size; ++j)
94  {
95  auto index=i*size-i*(i+1)/2+j;
96  precomputed_km[index]=kernel->kernel(i, j);
97  }
98  }
99 
100  for (auto current_run=0; current_run<m_num_runs; ++current_run)
101  {
102  m_kfold_x->build_subsets();
103  m_kfold_y->build_subsets();
104  for (auto current_fold=0; current_fold<m_num_folds; ++current_fold)
105  {
106  generate_inds(current_fold);
107  std::fill(m_inverted_inds.data(), m_inverted_inds.data()+m_inverted_inds.size(), -1);
108  for (index_t idx=0; idx<m_xy_inds.size(); ++idx)
109  m_inverted_inds[m_xy_inds[idx]]=idx;
110 
111  m_stack->add_subset(m_xy_inds);
112 
113  if (m_permuted_inds.size()!=m_xy_inds.size())
114  m_permuted_inds=SGVector<index_t>(m_xy_inds.size());
115 
116  m_inverted_permuted_inds.set_const(-1);
117 
118  for (auto n=0; n<m_num_null_samples; ++n)
119  {
120  std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0);
121  CMath::permute(m_permuted_inds);
122 
123  m_stack->add_subset(m_permuted_inds);
124  SGVector<index_t> inds=m_stack->get_last_subset()->get_subset_idx();
125  m_stack->remove_subset();
126 
127  for (int idx=0; idx<inds.size(); ++idx)
128  m_inverted_permuted_inds(inds[idx], n)=idx;
129  }
130  m_stack->remove_subset();
131 
132  terms_t terms;
133  for (auto i=0; i<size; ++i)
134  {
135  auto inverted_row=m_inverted_inds[i];
136  auto idx_base=i*size-i*(i+1)/2;
137  for (auto j=i; j<size; ++j)
138  {
139  auto inverted_col=m_inverted_inds[j];
140  if (inverted_row!=-1 && inverted_col!=-1)
141  {
142  auto idx=idx_base+j;
143  add_term_upper(terms, precomputed_km[idx], inverted_row, inverted_col);
144  }
145  }
146  }
147  auto statistic=compute(terms);
148 
149 #pragma omp parallel for
150  for (auto n=0; n<m_num_null_samples; ++n)
151  {
152  terms_t null_terms;
153  for (auto i=0; i<size; ++i)
154  {
155  auto inverted_row=m_inverted_permuted_inds(i, n);
156  auto idx_base=i*size-i*(i+1)/2;
157  for (auto j=i; j<size; ++j)
158  {
159  auto inverted_col=m_inverted_permuted_inds(j, n);
160  if (inverted_row!=-1 && inverted_col!=-1)
161  {
162  auto idx=idx_base+j;
163  if (inverted_row<=inverted_col)
164  add_term_upper(null_terms, precomputed_km[idx], inverted_row, inverted_col);
165  else
166  add_term_upper(null_terms, precomputed_km[idx], inverted_col, inverted_row);
167  }
168  }
169  }
170  null_samples[n]=compute(null_terms);
171  }
172 
173  std::sort(null_samples.data(), null_samples.data()+null_samples.size());
174  SG_SDEBUG("statistic=%f\n", statistic);
175  float64_t idx=null_samples.find_position_to_insert(statistic);
176  SG_SDEBUG("index=%f\n", idx);
177  auto p_value=1.0-idx/m_num_null_samples;
178  bool rejected=p_value<m_alpha;
179  SG_SDEBUG("p-value=%f, alpha=%f, rejected=%d\n", p_value, m_alpha, rejected);
180  m_rejections(current_run*m_num_folds+current_fold, k)=rejected;
181 
182  m_n_x=orig_n_x;
183  m_n_y=orig_n_y;
184  }
185  }
186  }
187  }
188 
189  void init()
190  {
191  SGVector<int64_t> dummy_labels_x(m_n_x);
192  SGVector<int64_t> dummy_labels_y(m_n_y);
193 
194  auto instance_x=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_x), m_num_folds);
195  auto instance_y=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_y), m_num_folds);
196  m_kfold_x=unique_ptr<CCrossValidationSplitting>(instance_x);
197  m_kfold_y=unique_ptr<CCrossValidationSplitting>(instance_y);
198 
199  m_stack=unique_ptr<CSubsetStack>(new CSubsetStack());
200 
201  const index_t size=m_n_x+m_n_y;
202  m_inverted_inds=SGVector<index_t>(size);
203  m_inverted_permuted_inds=SGMatrix<index_t>(size, m_num_null_samples);
204  }
205 
206  void generate_inds(index_t current_fold)
207  {
208  SGVector<index_t> x_inds=m_kfold_x->generate_subset_inverse(current_fold);
209  SGVector<index_t> y_inds=m_kfold_y->generate_subset_inverse(current_fold);
210  std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [this](index_t& val) { val += m_n_x; });
211 
212  m_n_x=x_inds.size();
213  m_n_y=y_inds.size();
214 
215  if (m_xy_inds.size()!=m_n_x+m_n_y)
216  m_xy_inds=SGVector<index_t>(m_n_x+m_n_y);
217 
218  std::copy(x_inds.data(), x_inds.data()+x_inds.size(), m_xy_inds.data());
219  std::copy(y_inds.data(), y_inds.data()+y_inds.size(), m_xy_inds.data()+x_inds.size());
220  }
221 
222  index_t m_num_runs;
223  index_t m_num_folds;
224  static constexpr index_t DEFAULT_NUM_RUNS=10;
225 
226  float64_t m_alpha;
227  static constexpr float64_t DEFAULT_ALPHA=0.05;
228 
229  unique_ptr<CCrossValidationSplitting> m_kfold_x;
230  unique_ptr<CCrossValidationSplitting> m_kfold_y;
231  unique_ptr<CSubsetStack> m_stack;
232 
233  SGVector<index_t> m_xy_inds;
234  SGVector<index_t> m_inverted_inds;
235  SGMatrix<float64_t> m_rejections;
236 
237 };
238 #endif // DOXYGEN_SHOULD_SKIP_THIS
239 }
240 
241 }
242 
243 }
244 #endif // CROSS_VALIDATION_MMD_H_
static void permute(SGVector< T > v, CRandom *rand=NULL)
Definition: Math.h:962
int32_t index_t
Definition: common.h:72
#define REQUIRE(x,...)
Definition: SGIO.h:181
#define ASSERT(x)
Definition: SGIO.h:176
double float64_t
Definition: common.h:60
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
#define SG_SDEBUG(...)
Definition: SGIO.h:153

SHOGUN Machine Learning Toolbox - Documentation