SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
NbodyTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
33 
34 using namespace shogun;
35 
38 {
	// Reset every member to its default (and register serializable parameters)
	// via init(), then overwrite the leaf size and distance type with the values
	// supplied by the caller.
39  init();
40 
41  m_leaf_size=leaf_size;
42  m_dist=d;
43 }
44 
46 {
	// Build the tree over the columns of `data`: validate the inputs, invalidate
	// any earlier k-NN results, take the feature matrix (one vector per column;
	// the build below runs over column indices 0..num_cols-1), then build the
	// tree recursively and install the result as the root.
47  REQUIRE(data,"data not set\n");
48  REQUIRE(m_leaf_size>0,"Leaf size should be greater than 0\n");
49 
50  m_knn_done=false;
51  m_data=data->get_feature_matrix();
52 
	// NOTE(review): recursive_build()/partition() permute m_vec_id[start..end] and
	// look up m_data columns through it, so m_vec_id must hold 0..num_cols-1 for
	// the new data before this call — confirm it is (re)initialized here.
55 
56  set_root(recursive_build(0,m_data.num_cols-1));
57 }
58 
60 {
	// k-nearest-neighbour query for every column of `data`. Results are stored
	// column-wise: column i of m_knn_dists / m_knn_indices receives the k
	// distances / training indices collected in a per-query CKNNHeap.
61  REQUIRE(data,"Query data not supplied\n")
62  REQUIRE(data->get_num_features()==m_data.num_rows,"query data dimension should be same as training data dimension\n")
	// NOTE(review): the two REQUIRE lines above have no trailing ';' and rely on the
	// macro expanding to a complete statement; other REQUIRE calls in this file end with ';'.
63 
	// NOTE(review): the done-flag is raised before any result is computed; if the
	// search below aborts, get_knn_dists()/get_knn_indices() may return incomplete results.
64  m_knn_done=true;
	// qfeats presumably holds data's feature matrix (dim rows, one query vector
	// per column, as the i*dim offsets below assume) — TODO confirm.
66  m_knn_dists=SGMatrix<float64_t>(k,qfeats.num_cols);
67  m_knn_indices=SGMatrix<index_t>(k,qfeats.num_cols);
68  int32_t dim=qfeats.num_rows;
69 
70  for (int32_t i=0;i<qfeats.num_cols;i++)
71  {
	// NOTE(review): heap is owned through a raw pointer; if anything between new and
	// delete throws, it leaks.
72  CKNNHeap* heap=new CKNNHeap(k);
	// NOTE(review): this cast does not depend on i and could be hoisted out of the
	// loop. If no tree was built, root stays NULL and is passed to min_dist() and
	// query_knn_single(), which reads node->data — consider REQUIRE(m_root,...).
73  bnode_t* root=NULL;
74  if (m_root)
75  root=dynamic_cast<bnode_t*>(m_root);
76 
77  float64_t mdist=min_dist(root,qfeats.matrix+i*dim,dim);
78  query_knn_single(heap,mdist,root,qfeats.matrix+i*dim,dim);
	// copy the k results for query i into column i (offset i*k) of each result matrix
79  memcpy(m_knn_dists.matrix+i*k,heap->get_dists(),k*sizeof(float64_t));
80  memcpy(m_knn_indices.matrix+i*k,heap->get_indices(),k*sizeof(index_t));
81 
82  delete(heap);
83  }
84 }
85 
87 {
	// Single-tree kernel density estimate: returns one log density per column of
	// `test`. For each query vector get_kde_single() walks the tree, maintaining a
	// log-space lower bound (min_bound) and bracket width (spread) on the kernel sum
	// over all training vectors, and stops refining once the atol/rtol tests pass.
88  int32_t dim=m_data.num_rows;
89  REQUIRE(test.num_rows==dim,"dimensions of training data and test data should be the same\n")
90 
	// The bounds track the kernel sum over all m_data.num_cols training vectors (the
	// result is divided by num_cols at the end), so atol is scaled by num_cols to match.
91  float64_t log_atol=CMath::log(atol*m_data.num_cols);
92  float64_t log_rtol=CMath::log(rtol);
93  float64_t log_kernel_norm=CKernelDensity::log_norm(kernel,h,dim);
94  SGVector<float64_t> log_density(test.num_cols);
95  for (int32_t i=0;i<test.num_cols;i++)
96  {
	// NOTE(review): loop-invariant cast; a NULL root (no tree built) would be passed
	// on to min_max_dist()/get_kde_single().
97  bnode_t* root=NULL;
98  if (m_root)
99  root=dynamic_cast<bnode_t*>(m_root);
100 
101  float64_t lower_dist=0;
102  float64_t upper_dist=0;
103  min_max_dist(test.matrix+i*dim,root,lower_dist,upper_dist,dim);
104 
	// Initial bracket for the whole tree: num_cols * K(farthest) .. num_cols * K(nearest), in log space.
105  float64_t min_bound=CMath::log(m_data.num_cols)+CKernelDensity::log_kernel(kernel,upper_dist,h);
106  float64_t max_bound=CMath::log(m_data.num_cols)+CKernelDensity::log_kernel(kernel,lower_dist,h);
107  float64_t spread=logdiffexp(max_bound,min_bound);
108 
	// min_bound/spread are passed twice: by value as the root's own bounds and by
	// reference as the running global bounds that get_kde_single() updates in place.
109  get_kde_single(root,test.matrix+i*dim,kernel,h,log_atol,log_rtol,log_kernel_norm,min_bound,spread,min_bound,spread);
	// Estimate = midpoint of the final bracket (min_bound + spread/2), times the kernel
	// normalisation, divided by the number of training vectors — all in log space.
110  log_density[i]=logsumexp(min_bound,spread-CMath::log(2))+log_kernel_norm-CMath::log(m_data.num_cols);
111  }
112 
113  return log_density;
114 }
115 
117 {
	// Dual-tree kernel density estimate: `qroot` is a tree over the columns of `test`,
	// and `qid` presumably maps query-tree positions to test columns (kde_dual writes
	// log_density[qid[i]]). Returns one log density per test column.
118  int32_t dim=m_data.num_rows;
119  REQUIRE(test.num_rows==dim,"dimensions of training data and test data should be the same\n")
120 
	// atol is scaled by (#reference * #query) because the bounds below cover all point pairs.
121  float64_t log_atol=CMath::log(atol*m_data.num_cols*test.num_cols);
122  float64_t log_rtol=CMath::log(rtol);
123  float64_t log_kernel_norm=CKernelDensity::log_norm(kernel,h,dim);
124  SGVector<float64_t> log_density(test.num_cols);
	// start every entry at log(0) = -inf so kde_dual can accumulate with logsumexp
125  log_density.fill_vector(log_density.vector,log_density.vlen,-CMath::INFTY);
126 
127  bnode_t* rroot=NULL;
128  if (m_root)
129  rroot=dynamic_cast<bnode_t*>(m_root);
130 
	// Initial bracket for the (reference root, query root) pair, scaled by both point counts.
131  float64_t upper_dist=max_dist_dual(rroot,qroot);
132  float64_t lower_dist=min_dist_dual(rroot,qroot);
133  float64_t min_bound=CMath::log(test.num_cols)+CMath::log(m_data.num_cols)+CKernelDensity::log_kernel(kernel,upper_dist,h);
134  float64_t max_bound=CMath::log(test.num_cols)+CMath::log(m_data.num_cols)+CKernelDensity::log_kernel(kernel,lower_dist,h);
135  float64_t spread=logdiffexp(max_bound,min_bound);
136 
	// kde_dual takes log_density by value (SGVector) yet its writes must reach this
	// vector — this relies on SGVector copies sharing storage (presumably ref-counted).
137  kde_dual(rroot,qroot,qid,test,log_density,kernel,h,log_atol,log_rtol,log_kernel_norm,min_bound,spread,min_bound,spread);
138 
	// log_n is presumably log(m_data.num_cols), normalising by the number of training
	// vectors as log_kernel_density does — TODO confirm.
140  for (int32_t i=0;i<test.num_cols;i++)
141  log_density[i]=log_density[i]+log_kernel_norm-log_n;
142 
143  return log_density;
144 }
145 
147 {
	// Return the distance matrix from the last k-NN query (column i = query i).
	// Raises SG_ERROR if no query has been run since the last build.
148  if (m_knn_done)
149  return m_knn_dists;
150 
151  SG_ERROR("knn query has not been executed yet\n");
	// reached only if SG_ERROR returns instead of throwing
152  return SGMatrix<float64_t>();
153 }
154 
156 {
	// Return the neighbour-index matrix from the last k-NN query (column i = query i).
	// Raises SG_ERROR if no query has been run since the last build.
157  if (m_knn_done)
158  return m_knn_indices;
159 
160  SG_ERROR("knn query has not been executed yet\n");
	// reached only if SG_ERROR returns instead of throwing
161  return SGMatrix<index_t>();
162 }
163 
164 void CNbodyTree::query_knn_single(CKNNHeap* heap, float64_t mdist, bnode_t* node, float64_t* arr, int32_t dim)
165 {
166  if (mdist>heap->get_max_dist())
167  return;
168 
169  if (node->data.is_leaf)
170  {
171  index_t start=node->data.start_idx;
172  index_t end=node->data.end_idx;
173 
174  for (int32_t i=start;i<=end;i++)
175  heap->push(m_vec_id[i],distance(m_vec_id[i],arr,dim));
176 
177  return;
178  }
179 
180  bnode_t* cleft=node->left();
181  bnode_t* cright=node->right();
182 
183  float64_t min_dist_left=min_dist(cleft,arr,dim);
184  float64_t min_dist_right=min_dist(cright,arr,dim);
185 
186  if (min_dist_left<=min_dist_right)
187  {
188  query_knn_single(heap,min_dist_left,cleft,arr,dim);
189  query_knn_single(heap,min_dist_right,cright,arr,dim);
190  }
191  else
192  {
193  query_knn_single(heap,min_dist_right,cright,arr,dim);
194  query_knn_single(heap,min_dist_left,cleft,arr,dim);
195  }
196 
197  SG_UNREF(cleft);
198  SG_UNREF(cright);
199 }
200 
202 {
	// Distance between training vector `vec` (column vec of m_data) and the query
	// array `arr` of length dim: per-coordinate terms from add_dim_dist() are summed,
	// then actual_dists() maps the sum to the final distance (presumably squared
	// difference / square root for the Euclidean case — depends on m_dist; confirm).
203  float64_t ret=0;
204  for (int32_t i=0;i<dim;i++)
205  ret+=add_dim_dist(m_data(i,vec)-arr[i]);
206 
207  return actual_dists(ret);
208 }
209 
210 CBinaryTreeMachineNode<NbodyTreeNodeData>* CNbodyTree::recursive_build(index_t start, index_t end)
211 {
212  bnode_t* node=new bnode_t();
213  init_node(node,start,end);
214 
215  // stopping critertia
216  if (end-start+1<m_leaf_size*2)
217  {
218  node->data.is_leaf=true;
219  return node;
220  }
221 
222  node->data.is_leaf=false;
223  index_t dim=find_split_dim(node);
224  index_t mid=(end+start)/2;
225  partition(dim,start,end,mid);
226 
227  bnode_t* child_left=recursive_build(start,mid);
228  bnode_t* child_right=recursive_build(mid+1,end);
229 
230  node->left(child_left);
231  node->right(child_right);
232 
233  return node;
234 }
235 
// Single-tree KDE refinement for the query vector `data` below `node`.
// min_bound_node/spread_node: log-space lower bound and bracket width of this
// node's contribution to the kernel sum. min_bound_global/spread_global: running
// log-space totals over the whole tree, updated in place. Returns early when either
// the local (per-node) or the global atol/rtol test passes; otherwise it either
// evaluates a leaf exactly or replaces this node's bounds by its children's and recurses.
236 void CNbodyTree::get_kde_single(bnode_t* node,float64_t* data, EKernelType kernel, float64_t h, float64_t log_atol, float64_t log_rtol,
237  float64_t log_norm, float64_t min_bound_node, float64_t spread_node, float64_t &min_bound_global, float64_t &spread_global)
238 {
	// NOTE(review): BUG — CMath::log of a count is generally non-integer, but it is
	// stored in int32_t here, truncating n_node and n_total before they are used in
	// the local criterion below. kde_dual stores the same quantity as float64_t;
	// these two should almost certainly be float64_t as well.
239  int32_t n_node=CMath::log(node->data.end_idx-node->data.start_idx+1);
240  int32_t n_total=CMath::log(m_data.num_cols);
241 
242  // local bound criterion met
243  if ((log_norm+spread_node+n_total-n_node)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_node))
244  return;
245 
246  // global bound criterion met
247  if ((log_norm+spread_global)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_global))
248  return;
249 
250  // node is leaf
251  if (node->data.is_leaf)
252  {
	// Replace this node's bracket in the global totals by exact evaluations: remove
	// its min bound and spread, then add each point's log-kernel value to min_bound_global.
253  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
254  spread_global=logdiffexp(spread_global,spread_node);
255 
256  for (int32_t i=node->data.start_idx;i<=node->data.end_idx;i++)
257  {
	// pt_eval is presumably log_kernel(kernel, distance(m_vec_id[i], data, dim), h) — confirm.
259  min_bound_global=logsumexp(pt_eval,min_bound_global);
260  }
261 
262  return;
263  }
264 
	// Internal node: presumably left()/right() return referenced pointers, released below.
265  bnode_t* lchild=node->left();
266  bnode_t* rchild=node->right();
267 
268  float64_t lower_dist=0;
269  float64_t upper_dist=0;
270  min_max_dist(data,lchild,lower_dist,upper_dist,m_data.num_rows);
271 
	// Child bracket: n * K(farthest) .. n * K(nearest), in log space.
272  int32_t n_l=lchild->data.end_idx-lchild->data.start_idx+1;
273  float64_t lower_bound_childl=CMath::log(n_l)+CKernelDensity::log_kernel(kernel,upper_dist,h);
	// NOTE(review): unqualified log(n_l) here vs CMath::log(n_l) on the line above
	// (same for n_r below) — presumably equal, but inconsistent.
274  float64_t spread_childl=logdiffexp(log(n_l)+CKernelDensity::log_kernel(kernel,lower_dist,h),lower_bound_childl);
275 
276  min_max_dist(data,rchild,lower_dist,upper_dist,m_data.num_rows);
277  int32_t n_r=rchild->data.end_idx-rchild->data.start_idx+1;
278  float64_t lower_bound_childr=CMath::log(n_r)+CKernelDensity::log_kernel(kernel,upper_dist,h);
279  float64_t spread_childr=logdiffexp(log(n_r)+CKernelDensity::log_kernel(kernel,lower_dist,h),lower_bound_childr);
280 
281  // update global bounds
	// (swap this node's contribution for the sum of its two children's)
282  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
283  min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
284  min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
285 
286  spread_global=logdiffexp(spread_global,spread_node);
287  spread_global=logsumexp(spread_global,spread_childl);
288  spread_global=logsumexp(spread_global,spread_childr);
289 
290  get_kde_single(lchild,data,kernel,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl,min_bound_global,spread_global);
291  get_kde_single(rchild,data,kernel,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr,min_bound_global,spread_global);
292 
293  SG_UNREF(lchild);
294  SG_UNREF(rchild);
295 }
296 
// Dual-tree KDE recursion over a (reference node, query node) pair.
// min_bound_node/spread_node: log-space lower bound and bracket width of the kernel
// sum over all (query, reference) point pairs in this node pair.
// min_bound_global/spread_global: running log-space totals for the whole computation,
// updated in place. Per-query contributions are accumulated (logsumexp) into
// log_density[qid[i]] for each query position i in querynode.
297 void CNbodyTree::kde_dual(bnode_t* refnode, bnode_t* querynode, SGVector<index_t> qid, SGMatrix<float64_t> qdata, SGVector<float64_t> log_density, EKernelType kernel_type, float64_t h, float64_t log_atol, float64_t log_rtol, float64_t log_norm, float64_t min_bound_node, float64_t spread_node, float64_t &min_bound_global, float64_t &spread_global)
298 {
299  int32_t dim=m_data.num_rows;
	// log(#reference points) + log(#query points) of this node pair
300  float64_t n_node=CMath::log(refnode->data.end_idx-refnode->data.start_idx+1)+CMath::log(querynode->data.end_idx-querynode->data.start_idx+1);
302 
	// n_total is presumably the same quantity for the full problem
	// (log(#reference)+log(#query)) — TODO confirm.
303  bool global_criterion=(log_norm+spread_global)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_global);
304  bool local_criterion=(log_norm+spread_node+n_total-n_node)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_node);
305 
306  // global bound criterion met || local bound criterion met
307  if (global_criterion || local_criterion)
308  {
309  // every query point in the node receives an equal share of this pair's midpoint estimate, min_bound + spread/2
310  float64_t center_density=logsumexp(min_bound_node,spread_node-CMath::log(2))-CMath::log(querynode->data.end_idx-querynode->data.start_idx+1);
311  for (int32_t i=querynode->data.start_idx;i<=querynode->data.end_idx;i++)
312  log_density[qid[i]]=logsumexp(log_density[qid[i]],center_density);
313 
314  return;
315  }
316 
317  // both are leaves
318  if (refnode->data.is_leaf && querynode->data.is_leaf)
319  {
	// replace this pair's bracket in the global totals by exact evaluations
320  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
321  spread_global=logdiffexp(spread_global,spread_node);
322 
323  // point by point evaluation of density
324  for (int32_t i=querynode->data.start_idx;i<=querynode->data.end_idx;i++)
325  {
	// q presumably starts at -inf (log 0) for each query point — confirm.
327  for (int32_t j=refnode->data.start_idx;j<=refnode->data.end_idx;j++)
328  {
329  float64_t pt_eval=CKernelDensity::log_kernel(kernel_type,distance(m_vec_id[j],qdata.matrix+dim*qid[i],dim),h);
330  q=logsumexp(q,pt_eval);
331  }
332 
333  min_bound_global=logsumexp(min_bound_global,q);
334  log_density[qid[i]]=logsumexp(log_density[qid[i]],q);
335  }
336 
337  return;
338  }
339 
340  // if query node is leaf - just recurse on the reference tree
341  if (querynode->data.is_leaf)
342  {
	// presumably left()/right() return referenced pointers, released via SG_UNREF below
343  bnode_t* lchild=refnode->left();
344  bnode_t* rchild=refnode->right();
345  int32_t queryn=querynode->data.end_idx-querynode->data.start_idx+1;
346 
347  // compute bounds for query node and left child of ref node
348  float64_t lower_dist=min_dist_dual(querynode,lchild);
349  float64_t upper_dist=max_dist_dual(querynode,lchild);
350  int32_t refn_l=lchild->data.end_idx-lchild->data.start_idx+1;
351  float64_t lower_bound_childl=CMath::log(queryn)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
352  float64_t spread_childl=logdiffexp(CMath::log(queryn)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_childl);
353 
354  // compute bounds for query node and right child of ref node
355  lower_dist=min_dist_dual(querynode,rchild);
356  upper_dist=max_dist_dual(querynode,rchild);
357  int32_t refn_r=rchild->data.end_idx-rchild->data.start_idx+1;
358  float64_t lower_bound_childr=CMath::log(queryn)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
359  float64_t spread_childr=logdiffexp(CMath::log(queryn)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_childr);
360 
361  // update global bounds
362  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
363  min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
364  min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
365 
366  spread_global=logdiffexp(spread_global,spread_node);
367  spread_global=logsumexp(spread_global,spread_childl);
368  spread_global=logsumexp(spread_global,spread_childr);
369 
370  kde_dual(lchild,querynode,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl, min_bound_global,spread_global);
371  kde_dual(rchild,querynode,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr, min_bound_global,spread_global);
372 
373  SG_UNREF(lchild);
374  SG_UNREF(rchild);
375  return;
376  }
377 
378  // if reference node is leaf - just recurse on the query tree
379  if (refnode->data.is_leaf)
380  {
381  int32_t ref_n=refnode->data.end_idx-refnode->data.start_idx+1;
382  bnode_t* lchild=querynode->left();
383  bnode_t* rchild=querynode->right();
384 
385  int32_t query_nl=lchild->data.end_idx-lchild->data.start_idx+1;
386  int32_t query_nr=rchild->data.end_idx-rchild->data.start_idx+1;
387 
388  // compute bounds for left child of query node and ref node
389  float64_t lower_dist=min_dist_dual(refnode,lchild);
390  float64_t upper_dist=max_dist_dual(refnode,lchild);
391  float64_t lower_bound_childl=CMath::log(query_nl)+CMath::log(ref_n)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
392  float64_t spread_childl=logdiffexp(CMath::log(query_nl)+CMath::log(ref_n)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_childl);
393 
394  // compute bounds for right child of query node and ref node
	// NOTE(review): BUG — rchild is querynode's own right child, so the two calls
	// below measure querynode against its child instead of refnode against rchild.
	// The left-child case above uses refnode; these should be
	// min_dist_dual(refnode,rchild) / max_dist_dual(refnode,rchild). As written the
	// bracket for the (refnode, rchild) pair is wrong, which can cause incorrect pruning.
395  lower_dist=min_dist_dual(querynode,rchild);
396  upper_dist=max_dist_dual(querynode,rchild);
397  float64_t lower_bound_childr=CMath::log(query_nr)+CMath::log(ref_n)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
398  float64_t spread_childr=logdiffexp(CMath::log(query_nr)+CMath::log(ref_n)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_childr);
399 
400  // update global bounds
401  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
402  min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
403  min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
404 
405  spread_global=logdiffexp(spread_global,spread_node);
406  spread_global=logsumexp(spread_global,spread_childl);
407  spread_global=logsumexp(spread_global,spread_childr);
408 
409  kde_dual(refnode,lchild,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl,min_bound_global,spread_global);
410  kde_dual(refnode,rchild,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr,min_bound_global,spread_global);
411 
412  SG_UNREF(lchild);
413  SG_UNREF(rchild);
414  return;
415  }
416 
417  // if none of above - apply 4 way recursion in both trees: left-left, left-right, right-left, right-right
	// (pair names below are query-child first, reference-child second)
418  bnode_t* refchildl=refnode->left();
419  bnode_t* refchildr=refnode->right();
420  bnode_t* querychildl=querynode->left();
421  bnode_t* querychildr=querynode->right();
422 
423  float64_t refn_l=refchildl->data.end_idx-refchildl->data.start_idx+1;
424  float64_t refn_r=refchildr->data.end_idx-refchildr->data.start_idx+1;
425  float64_t queryn_l=querychildl->data.end_idx-querychildl->data.start_idx+1;
426  float64_t queryn_r=querychildr->data.end_idx-querychildr->data.start_idx+1;
427 
428  // left child-left child bounds
429  float64_t lower_dist=min_dist_dual(querychildl,refchildl);
430  float64_t upper_dist=max_dist_dual(querychildl,refchildl);
431  float64_t lower_bound_ll=CMath::log(queryn_l)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
432  float64_t spread_ll=logdiffexp(CMath::log(queryn_l)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_ll);
433 
434  // left-right bounds
435  lower_dist=min_dist_dual(querychildl,refchildr);
436  upper_dist=max_dist_dual(querychildl,refchildr);
437  float64_t lower_bound_lr=CMath::log(queryn_l)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
438  float64_t spread_lr=logdiffexp(CMath::log(queryn_l)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_lr);
439 
440  // right-left bounds
441  lower_dist=min_dist_dual(querychildr,refchildl);
442  upper_dist=max_dist_dual(querychildr,refchildl);
443  float64_t lower_bound_rl=CMath::log(queryn_r)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
444  float64_t spread_rl=logdiffexp(CMath::log(queryn_r)+CMath::log(refn_l)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_rl);
445 
446  // right-right bounds
447  lower_dist=min_dist_dual(querychildr,refchildr);
448  upper_dist=max_dist_dual(querychildr,refchildr);
449  float64_t lower_bound_rr=CMath::log(queryn_r)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,upper_dist,h);
450  float64_t spread_rr=logdiffexp(CMath::log(queryn_r)+CMath::log(refn_r)+CKernelDensity::log_kernel(kernel_type,lower_dist,h),lower_bound_rr);
451 
452  // update global bound and spread
	// (remove this pair's contribution, add the four child pairs')
453  min_bound_global=logdiffexp(min_bound_global,min_bound_node);
454  min_bound_global=logsumexp(min_bound_global,lower_bound_ll);
455  min_bound_global=logsumexp(min_bound_global,lower_bound_lr);
456  min_bound_global=logsumexp(min_bound_global,lower_bound_rl);
457  min_bound_global=logsumexp(min_bound_global,lower_bound_rr);
458 
459  spread_global=logdiffexp(spread_global,spread_node);
460  spread_global=logsumexp(spread_global,spread_ll);
461  spread_global=logsumexp(spread_global,spread_lr);
462  spread_global=logsumexp(spread_global,spread_rl);
463  spread_global=logsumexp(spread_global,spread_rr);
464 
465  // left-left and left-right recursions
466  kde_dual(refchildl,querychildl,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_ll,spread_ll, min_bound_global,spread_global);
467  kde_dual(refchildr,querychildl,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_lr,spread_lr, min_bound_global,spread_global);
468 
469  // right-left and right-right recursions
470  kde_dual(refchildl,querychildr,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_rl,spread_rl, min_bound_global,spread_global);
471  kde_dual(refchildr,querychildr,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_rr,spread_rr, min_bound_global, spread_global);
472 
473  SG_UNREF(refchildl);
474  SG_UNREF(refchildr);
475  SG_UNREF(querychildl);
476  SG_UNREF(querychildr);
477 }
478 
479 void CNbodyTree::partition(index_t dim, index_t start, index_t end, index_t mid)
480 {
481  // in-place partial quick-sort
482  index_t left=start;
483  index_t right=end;
484  while (true)
485  {
486  index_t midindex=left;
487  for (int32_t i=left;i<right;i++)
488  {
489  if (m_data(dim,m_vec_id[i])<m_data(dim,m_vec_id[right]))
490  {
491  CMath::swap(*(m_vec_id.vector+i),*(m_vec_id.vector+midindex));
492  midindex+=1;
493  }
494  }
495 
496  CMath::swap(*(m_vec_id.vector+midindex),*(m_vec_id.vector+right));
497  if (midindex==mid)
498  break;
499  else if (midindex<mid)
500  left=midindex+1;
501  else
502  right=midindex-1;
503  }
504 }
505 
506 index_t CNbodyTree::find_split_dim(bnode_t* node)
507 {
508  SGVector<float64_t> upper_bounds=node->data.bbox_upper;
509  SGVector<float64_t> lower_bounds=node->data.bbox_lower;
510 
511  index_t max_dim=0;
512  float64_t max_spread=-1;
513  for (int32_t i=0;i<m_data.num_rows;i++)
514  {
515  float64_t spread=upper_bounds[i]-lower_bounds[i];
516  if (spread>max_spread)
517  {
518  max_spread=spread;
519  max_dim=i;
520  }
521  }
522 
523  return max_dim;
524 }
525 
// Reset members to their defaults (leaf size 1, Euclidean distance, no k-NN
// results) and register the serializable members with SG_ADD.
526 void CNbodyTree::init()
527 {
529  m_leaf_size=1;
531  m_dist=D_EUCLIDEAN;
532  m_knn_done=false;
533  m_knn_dists=SGMatrix<float64_t>();
534  m_knn_indices=SGMatrix<index_t>();
535 
	// NOTE(review): m_dist is not registered below, so it would not be
	// serialized — confirm whether that is intentional. The "knn_done" and
	// "knn_indices" names lack the "m_" prefix used by the others; they are left
	// unchanged because renaming would break compatibility with existing serialized data.
536  SG_ADD(&m_data,"m_data","data matrix",MS_NOT_AVAILABLE);
537  SG_ADD(&m_leaf_size,"m_leaf_size","leaf size",MS_NOT_AVAILABLE);
538  SG_ADD(&m_vec_id,"m_vec_id","id of vectors",MS_NOT_AVAILABLE);
539  SG_ADD(&m_knn_done,"knn_done","knn done or not",MS_NOT_AVAILABLE);
540  SG_ADD(&m_knn_dists,"m_knn_dists","knn distances",MS_NOT_AVAILABLE);
541  SG_ADD(&m_knn_indices,"knn_indices","knn indices",MS_NOT_AVAILABLE);
542 }

SHOGUN Machine Learning Toolbox - Documentation