41 m_leaf_size=leaf_size;
48 REQUIRE(m_leaf_size>0,
"Leaf size should be greater than 0\n");
61 REQUIRE(data,
"Query data not supplied\n")
70 for (int32_t i=0;i<qfeats.num_cols;i++)
78 query_knn_single(heap,mdist,root,qfeats.matrix+i*dim,dim);
89 REQUIRE(test.
num_rows==dim,
"dimensions of training data and test data should be the same\n")
95 for (int32_t i=0;i<test.
num_cols;i++)
107 float64_t spread=logdiffexp(max_bound,min_bound);
109 get_kde_single(root,test.
matrix+i*dim,kernel,h,log_atol,log_rtol,log_kernel_norm,min_bound,spread,min_bound,spread);
119 REQUIRE(test.
num_rows==dim,
"dimensions of training data and test data should be the same\n")
135 float64_t spread=logdiffexp(max_bound,min_bound);
137 kde_dual(rroot,qroot,qid,test,log_density,kernel,h,log_atol,log_rtol,log_kernel_norm,min_bound,spread,min_bound,spread);
140 for (int32_t i=0;i<test.
num_cols;i++)
141 log_density[i]=log_density[i]+log_kernel_norm-log_n;
151 SG_ERROR(
"knn query has not been executed yet\n");
158 return m_knn_indices;
160 SG_ERROR(
"knn query has not been executed yet\n");
169 if (node->
data.is_leaf)
174 for (int32_t i=start;i<=end;i++)
186 if (min_dist_left<=min_dist_right)
188 query_knn_single(heap,min_dist_left,cleft,arr,dim);
189 query_knn_single(heap,min_dist_right,cright,arr,dim);
193 query_knn_single(heap,min_dist_right,cright,arr,dim);
194 query_knn_single(heap,min_dist_left,cleft,arr,dim);
204 for (int32_t i=0;i<dim;i++)
216 if (end-start+1<m_leaf_size*2)
218 node->
data.is_leaf=
true;
222 node->
data.is_leaf=
false;
223 index_t dim=find_split_dim(node);
225 partition(dim,start,end,mid);
227 bnode_t* child_left=recursive_build(start,mid);
228 bnode_t* child_right=recursive_build(mid+1,end);
230 node->
left(child_left);
231 node->
right(child_right);
243 if ((log_norm+spread_node+n_total-n_node)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_node))
247 if ((log_norm+spread_global)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_global))
251 if (node->
data.is_leaf)
253 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
254 spread_global=logdiffexp(spread_global,spread_node);
256 for (int32_t i=node->
data.start_idx;i<=node->data.end_idx;i++)
259 min_bound_global=logsumexp(pt_eval,min_bound_global);
272 int32_t n_l=lchild->
data.end_idx-lchild->
data.start_idx+1;
277 int32_t n_r=rchild->
data.end_idx-rchild->
data.start_idx+1;
282 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
283 min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
284 min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
286 spread_global=logdiffexp(spread_global,spread_node);
287 spread_global=logsumexp(spread_global,spread_childl);
288 spread_global=logsumexp(spread_global,spread_childr);
290 get_kde_single(lchild,data,kernel,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl,min_bound_global,spread_global);
291 get_kde_single(rchild,data,kernel,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr,min_bound_global,spread_global);
297 void CNbodyTree::kde_dual(
bnode_t* refnode,
bnode_t* querynode,
SGVector<index_t> qid,
SGMatrix<float64_t> qdata,
SGVector<float64_t> log_density,
EKernelType kernel_type,
float64_t h,
float64_t log_atol,
float64_t log_rtol,
float64_t log_norm,
float64_t min_bound_node,
float64_t spread_node,
float64_t &min_bound_global,
float64_t &spread_global)
303 bool global_criterion=(log_norm+spread_global)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_global);
304 bool local_criterion=(log_norm+spread_node+n_total-n_node)<=logsumexp(log_atol,log_rtol+log_norm+min_bound_node);
307 if (global_criterion || local_criterion)
311 for (int32_t i=querynode->
data.start_idx;i<=querynode->data.end_idx;i++)
312 log_density[qid[i]]=logsumexp(log_density[qid[i]],center_density);
318 if (refnode->
data.is_leaf && querynode->
data.is_leaf)
320 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
321 spread_global=logdiffexp(spread_global,spread_node);
324 for (int32_t i=querynode->
data.start_idx;i<=querynode->data.end_idx;i++)
327 for (int32_t j=refnode->
data.start_idx;j<=refnode->data.end_idx;j++)
330 q=logsumexp(q,pt_eval);
333 min_bound_global=logsumexp(min_bound_global,q);
334 log_density[qid[i]]=logsumexp(log_density[qid[i]],q);
341 if (querynode->
data.is_leaf)
345 int32_t queryn=querynode->
data.end_idx-querynode->
data.start_idx+1;
350 int32_t refn_l=lchild->
data.end_idx-lchild->
data.start_idx+1;
357 int32_t refn_r=rchild->
data.end_idx-rchild->
data.start_idx+1;
362 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
363 min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
364 min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
366 spread_global=logdiffexp(spread_global,spread_node);
367 spread_global=logsumexp(spread_global,spread_childl);
368 spread_global=logsumexp(spread_global,spread_childr);
370 kde_dual(lchild,querynode,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl, min_bound_global,spread_global);
371 kde_dual(rchild,querynode,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr, min_bound_global,spread_global);
379 if (refnode->
data.is_leaf)
381 int32_t ref_n=refnode->
data.end_idx-refnode->
data.start_idx+1;
385 int32_t query_nl=lchild->
data.end_idx-lchild->
data.start_idx+1;
386 int32_t query_nr=rchild->
data.end_idx-rchild->
data.start_idx+1;
401 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
402 min_bound_global=logsumexp(min_bound_global,lower_bound_childl);
403 min_bound_global=logsumexp(min_bound_global,lower_bound_childr);
405 spread_global=logdiffexp(spread_global,spread_node);
406 spread_global=logsumexp(spread_global,spread_childl);
407 spread_global=logsumexp(spread_global,spread_childr);
409 kde_dual(refnode,lchild,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childl,spread_childl,min_bound_global,spread_global);
410 kde_dual(refnode,rchild,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_childr,spread_childr,min_bound_global,spread_global);
453 min_bound_global=logdiffexp(min_bound_global,min_bound_node);
454 min_bound_global=logsumexp(min_bound_global,lower_bound_ll);
455 min_bound_global=logsumexp(min_bound_global,lower_bound_lr);
456 min_bound_global=logsumexp(min_bound_global,lower_bound_rl);
457 min_bound_global=logsumexp(min_bound_global,lower_bound_rr);
459 spread_global=logdiffexp(spread_global,spread_node);
460 spread_global=logsumexp(spread_global,spread_ll);
461 spread_global=logsumexp(spread_global,spread_lr);
462 spread_global=logsumexp(spread_global,spread_rl);
463 spread_global=logsumexp(spread_global,spread_rr);
466 kde_dual(refchildl,querychildl,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_ll,spread_ll, min_bound_global,spread_global);
467 kde_dual(refchildr,querychildl,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_lr,spread_lr, min_bound_global,spread_global);
470 kde_dual(refchildl,querychildr,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_rl,spread_rl, min_bound_global,spread_global);
471 kde_dual(refchildr,querychildr,qid,qdata,log_density,kernel_type,h,log_atol,log_rtol,log_norm,lower_bound_rr,spread_rr, min_bound_global, spread_global);
487 for (int32_t i=left;i<right;i++)
499 else if (midindex<mid)
515 float64_t spread=upper_bounds[i]-lower_bounds[i];
516 if (spread>max_spread)
526 void CNbodyTree::init()
void range_fill(T start=0)
void push(index_t index, float64_t dist)
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
static void fill_vector(T *vec, int32_t len, T value)
int32_t get_num_features() const
SGVector< index_t > get_indices()
SGMatrix< index_t > get_knn_indices()
static const float64_t INFTY
infinity
static float64_t log_kernel(EKernelType kernel, float64_t dist, float64_t width)
SGMatrix< ST > get_feature_matrix()
float64_t distance(index_t vec, float64_t *arr, int32_t dim)
float64_t add_dim_dist(float64_t d)
CTreeMachineNode< NbodyTreeNodeData > * m_root
void build_tree(CDenseFeatures< float64_t > *data)
void set_root(CTreeMachineNode< NbodyTreeNodeData > *root)
structure to store data of a node of N-Body tree. This can be used as a template type in TreeMachineN...
SGMatrix< float64_t > m_data
virtual float64_t max_dist_dual(bnode_t *nodeq, bnode_t *noder)=0
CNbodyTree(int32_t leaf_size=1, EDistanceType d=D_EUCLIDEAN)
SGMatrix< float64_t > get_knn_dists()
virtual float64_t min_dist(bnode_t *node, float64_t *feat, int32_t dim)=0
void query_knn(CDenseFeatures< float64_t > *data, int32_t k)
static float64_t log_norm(EKernelType kernel, float64_t width, int32_t dim)
void right(CBinaryTreeMachineNode *r)
SGVector< float64_t > log_kernel_density_dual(SGMatrix< float64_t > test, SGVector< index_t > qid, bnode_t *qroot, EKernelType kernel, float64_t h, float64_t atol, float64_t rtol)
SGVector< float64_t > log_kernel_density(SGMatrix< float64_t > test, EKernelType kernel, float64_t h, float64_t atol, float64_t rtol)
float64_t actual_dists(float64_t dists)
virtual float64_t min_dist_dual(bnode_t *nodeq, bnode_t *noder)=0
all of classes and functions are contained in the shogun namespace
virtual void min_max_dist(float64_t *pt, bnode_t *node, float64_t &lower, float64_t &upper, int32_t dim)=0
SGVector< index_t > m_vec_id
CBinaryTreeMachineNode< NbodyTreeNodeData > bnode_t
static float64_t log(float64_t v)
static void swap(T &a, T &b)
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
virtual void init_node(bnode_t *node, index_t start, index_t end)=0
SGVector< float64_t > get_dists()
void left(CBinaryTreeMachineNode *l)
This class implements a specialized version of max heap structure. This heap specializes in storing t...