SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CoverTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2011 by Singularity Institute for Artificial Intelligence
3  * All Rights Reserved
4  *
5  * Written by David Crane <dncrane@gmail.com>
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU Affero General Public License v3 as
9  * published by the Free Software Foundation and including the exceptions
10  * at http://opencog.org/wiki/Licenses
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU Affero General Public License
18  * along with this program; if not, write to:
19  * Free Software Foundation, Inc.,
20  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  *
22  * Shogun modifications by Sergey Lisitsyn
23  */
24 
25 #ifndef _COVER_TREE_H
26 #define _COVER_TREE_H
27 
28 #include <vector>
29 #include <algorithm>
30 #include <map>
31 #include <set>
32 #include <cmath>
33 #include <float.h>
34 #include <iostream>
35 
36 namespace shogun
37 {
38 
/** @brief Cover tree over a metric space, supporting insertion, removal and
 * k-nearest-neighbour queries.
 *
 * Level i of the tree uses the radius base^i: nodes visible at level i are
 * more than base^i apart (separation) and every child attached at level i
 * lies within base^i of its parent (covering); isValidTree() checks both.
 *
 * Requirements on Point (as used by this file): copy construction,
 * operator== (used to search a node's point list), and a const member
 * distance(const Point&) whose result converts to double.
 *
 * NOTE(review): the tree owns its nodes through raw pointers and deletes
 * them in the destructor, but declares no copy constructor or copy
 * assignment. Implicit copies would share nodes and delete them twice; do
 * not copy a CoverTree.
 */
template<class Point>
class CoverTree
{
    /** A node of the cover tree. All points stored in one node are at
     * distance 0 from each other; the first one is the representative. */
    class CoverTreeNode
    {
    private:
        //_childMap[i] is a vector of the node's children at level i
        std::map<int,std::vector<CoverTreeNode*> > _childMap;
        //_points is all of the points with distance 0 which are not equal.
        std::vector<Point> _points;
    public:
        CoverTreeNode(const Point& p);
        /** Children attached at the given level (a copy; empty if none).
         * The node itself is not included. */
        std::vector<CoverTreeNode*> getChildren(int level) const;
        void addChild(int level, CoverTreeNode* p);
        void removeChild(int level, CoverTreeNode* p);
        /** Adds p unless an equal point is already stored. */
        void addPoint(const Point& p);
        void removePoint(const Point& p);
        /** All points stored in this node. */
        const std::vector<Point>& getPoints() const { return _points; }
        /** Distance between the representative points of two nodes. */
        double distance(const CoverTreeNode& p) const;
        /** Distance from the representative point to p. Without this
         * overload a Point argument would be converted into a temporary
         * CoverTreeNode just to read the point back out of it. */
        double distance(const Point& p) const { return _points[0].distance(p); }

        /** True when exactly one point is stored. */
        bool isSingle() const;
        /** True when a point equal to p is stored. */
        bool hasPoint(const Point& p) const;

        /** The representative (first stored) point. */
        const Point& getPoint() const;

        /** Children from every level, concatenated. */
        std::vector<CoverTreeNode*> getAllChildren() const;
    }; // CoverTreeNode class
    private:
    typedef std::pair<double, CoverTreeNode*> distNodePair;

    CoverTreeNode* _root;
    unsigned int _numNodes;
    int _maxLevel;//base^_maxLevel should be the max distance
    //between any 2 points
    int _minLevel;//A level beneath which there are no more new nodes.

    /** Nodes nearest to p (at most k), closest first. */
    std::vector<CoverTreeNode*>
        kNearestNodes(const Point& p, const unsigned int& k) const;
    /** Recursive insertion step; Qi is the cover set at `level` as
     * (distance to p, node) pairs. Returns true while p still needs a
     * parent, false once it has been attached. */
    bool insert_rec(const Point& p,
                    const std::vector<distNodePair>& Qi,
                    const int& level);

    /** Node of Q nearest to p, together with its distance. */
    distNodePair distance(const Point& p,
                          const std::vector<CoverTreeNode*>& Q);


    /** Recursive removal step; coverSets[level] is the cover set at
     * `level`; multi becomes true once p was dropped from a node holding
     * several points. */
    void remove_rec(const Point& p,
                    std::map<int,std::vector<distNodePair> >& coverSets,
                    int level,
                    bool& multi);

    public:
    /** Base of the level radii: level i uses base^i. */
    static const double base;

    /** @param maxDist expected upper bound on the distance between any two
     *                 points; determines the top level of the tree.
     *  @param points  initial points, inserted in order. */
    CoverTree(const double& maxDist,
              const std::vector<Point>& points=std::vector<Point>());
    /** Deletes every node of the tree. */
    ~CoverTree();

    /** Checks the separation and covering invariants level by level,
     * printing the first violation to std::cout. */
    bool isValidTree() const;

    /** Inserts newPoint; a point at distance 0 from an existing node is
     * stored in that node instead of creating a new one. */
    void insert(const Point& newPoint);

    /** Removes p from the tree if it is stored there. */
    void remove(const Point& p);

    /** Points of the nodes nearest to p, closest node first. May return
     * more than k points when the last node visited stores several. */
    std::vector<Point> kNearestNeighbors(const Point& p, const unsigned int& k) const;

    /** The root node, or NULL for an empty tree. */
    CoverTreeNode* getRoot() const;

}; // CoverTree class
180 
181 template<class Point>
182 const double CoverTree<Point>::base = 2.0;
183 
184 template<class Point>
185 CoverTree<Point>::CoverTree(const double& maxDist,
186  const std::vector<Point>& points)
187 {
188  _root=NULL;
189  _numNodes=0;
190  _maxLevel=ceilf(log(maxDist)/log(base));
191  _minLevel=_maxLevel-1;
192  typename std::vector<Point>::const_iterator it;
193  for(it=points.begin(); it!=points.end(); ++it) {
194  this->insert(*it);
195  }
196 }
197 
198 template<class Point>
200 {
201  if(_root==NULL) return;
202  //Get all of the root's children (from any level),
203  //delete the root, repeat for each of the children
204  std::vector<CoverTreeNode*> nodes;
205  nodes.push_back(_root);
206  while(!nodes.empty()) {
207  CoverTreeNode* byeNode = nodes[0];
208  nodes.erase(nodes.begin());
209  std::vector<CoverTreeNode*> children = byeNode->getAllChildren();
210  nodes.insert(nodes.begin(),children.begin(),children.end());
211  //std::cout << _numNodes << "\n";
212  delete byeNode;
213  //_numNodes--;
214  }
215 }
216 
217 template<class Point>
218 std::vector<typename CoverTree<Point>::CoverTreeNode*>
219 CoverTree<Point>::kNearestNodes(const Point& p, const unsigned int& k) const
220 {
221  if(_root==NULL) return std::vector<CoverTreeNode*>();
222  //maxDist is the kth nearest known point to p, and also the farthest
223  //point from p in the set minNodes defined below.
224  double maxDist = p.distance(_root->getPoint());
225  //minNodes stores the k nearest known points to p.
226  std::set<distNodePair> minNodes;
227 
228  minNodes.insert(std::make_pair(maxDist,_root));
229  std::vector<distNodePair> Qj(1,std::make_pair(maxDist,_root));
230  for(int level = _maxLevel; level>=_minLevel;level--) {
231  typename std::vector<distNodePair>::const_iterator it;
232  int size = Qj.size();
233  for(int i=0; i<size; i++) {
234  std::vector<CoverTreeNode*> children =
235  Qj[i].second->getChildren(level);
236  typename std::vector<CoverTreeNode*>::const_iterator it2;
237  for(it2=children.begin(); it2!=children.end(); ++it2) {
238  double d = p.distance((*it2)->getPoint());
239  if(d < maxDist || minNodes.size() < k) {
240  minNodes.insert(std::make_pair(d,*it2));
241  //--minNodes.end() gives us an iterator to the greatest
242  //element of minNodes.
243  if(minNodes.size() > k) minNodes.erase(--minNodes.end());
244  maxDist = (--minNodes.end())->first;
245  }
246  Qj.push_back(std::make_pair(d,*it2));
247  }
248  }
249  double sep = maxDist + pow(base, level);
250  size = Qj.size();
251  for(int i=0; i<size; i++) {
252  if(Qj[i].first > sep) {
253  //quickly removes an element from a vector w/o preserving order.
254  Qj[i]=Qj.back();
255  Qj.pop_back();
256  size--; i--;
257  }
258  }
259  }
260  std::vector<CoverTreeNode*> kNN;
261  typename std::set<distNodePair>::const_iterator it;
262  for(it=minNodes.begin();it!=minNodes.end();++it) {
263  kNN.push_back(it->second);
264  }
265  return kNN;
266 }
267 template<class Point>
268 bool CoverTree<Point>::insert_rec(const Point& p,
269  const std::vector<distNodePair>& Qi,
270  const int& level)
271 {
272  std::vector<std::pair<double, CoverTreeNode*> > Qj;
273  double sep = pow(base,level);
274  double minDist = DBL_MAX;
275  std::pair<double,CoverTreeNode*> minQiDist(DBL_MAX,NULL);
276  typename std::vector<std::pair<double, CoverTreeNode*> >::const_iterator it;
277  for(it=Qi.begin(); it!=Qi.end(); ++it) {
278  if(it->first<minQiDist.first) minQiDist = *it;
279  if(it->first<minDist) minDist=it->first;
280  if(it->first<=sep) Qj.push_back(*it);
281  std::vector<CoverTreeNode*> children = it->second->getChildren(level);
282  typename std::vector<CoverTreeNode*>::const_iterator it2;
283  for(it2=children.begin();it2!=children.end();++it2) {
284  double d = p.distance((*it2)->getPoint());
285  if(d<minDist) minDist = d;
286  if(d<=sep) {
287  Qj.push_back(std::make_pair(d,*it2));
288  }
289  }
290  }
291  //std::cout << "level: " << level << ", sep: " << sep << ", dist: " << minQDist.first << "\n";
292  if(minDist > sep) {
293  return true;
294  } else {
295  bool found = insert_rec(p,Qj,level-1);
296  //distNodePair minQiDist = distance(p,Qi);
297  if(found && minQiDist.first <= sep) {
298  if(level-1<_minLevel) _minLevel=level-1;
299  minQiDist.second->addChild(level,
300  new CoverTreeNode(p));
301  //std::cout << "parent is ";
302  //minQiDist.second->getPoint().print();
303  _numNodes++;
304  return false;
305  } else {
306  return found;
307  }
308  }
309 }
310 
/** Recursive step of remove().
 *
 * coverSets[level] must already hold the cover set at `level` as
 * (distance to p, node) pairs. This call appends to coverSets[level-1] the
 * entries of that set, and their level-`level` children, lying within
 * base^level of p. It then recurses down to _minLevel. On the way back up,
 * if the node nearest to p at this level stores p:
 *  - if multi is already set, nothing more is done;
 *  - if the node stores several points, only p is dropped and multi is set;
 *  - otherwise the node is unlinked from its parent (when it was found as a
 *    child at this level) and taken out of coverSets[level-1]. Each of its
 *    level-(level-1) children is re-attached to the first cover-set node
 *    within base^i of it, starting at i = level-1 and climbing i while none
 *    is close enough. The node is deleted only if a parent was found here.
 * The root is never anyone's child, so parent stays NULL for it and it is
 * not deleted here.
 */
template<class Point>
void CoverTree<Point>::remove_rec(const Point& p,
                                  std::map<int,std::vector<distNodePair> >& coverSets,
                                  int level,
                                  bool& multi)
{
    std::vector<distNodePair>& Qi = coverSets[level];
    std::vector<distNodePair>& Qj = coverSets[level-1];
    double minDist = DBL_MAX;
    CoverTreeNode* minNode = _root;
    // Set only when the node holding p is found as a child at this level.
    CoverTreeNode* parent = 0;
    double sep = pow(base, level);
    typename std::vector<distNodePair>::const_iterator it_;
    //set Qj to be all children q of Qi such that p.distance(q)<=sep
    //and also keep track of the minimum distance from p to a node in Qj
    //note that every node has itself as a child, but the
    //getChildren function only returns non-self-children.
    for(it_=Qi.begin();it_!=Qi.end();++it_) {
        std::vector<CoverTreeNode*> children = it_->second->getChildren(level);
        double dist = it_->first;
        if(dist<minDist) {
            minDist = dist;
            minNode = it_->second;
        }
        if(dist <= sep) {
            Qj.push_back(*it_);
        }
        typename std::vector<CoverTreeNode*>::const_iterator it2;
        for(it2=children.begin();it2!=children.end();++it2) {
            dist = p.distance((*it2)->getPoint());
            if(dist<minDist) {
                minDist = dist;
                minNode = *it2;
                if(dist == 0.0) parent = it_->second;
            }
            if(dist <= sep) {
                Qj.push_back(std::make_pair(dist,*it2));
            }
        }
    }
    // Handle the lower levels first; the fix-ups below run bottom-up.
    if(level>_minLevel) remove_rec(p,coverSets,level-1,multi);
    if(minNode->hasPoint(p)) {
        //the multi flag indicates the point we removed is from a
        //node containing multiple points, and we have removed it,
        //so we don't need to do anything else.
        if(multi) return;
        if(!minNode->isSingle()) {
            minNode->removePoint(p);
            multi=true;
            return;
        }
        if(parent!=NULL) parent->removeChild(level, minNode);
        std::vector<CoverTreeNode*> children = minNode->getChildren(level-1);
        // Drop the removed node from the next level's cover set so it cannot
        // be chosen as a new parent below.
        std::vector<distNodePair>& Q = coverSets[level-1];
        if(Q.size()==1 && Q[0].second==minNode) {
            Q.pop_back();
        } else {
            for(unsigned int i=0;i<Q.size();i++) {
                if(Q[i].second==minNode) {
                    Q[i]=Q.back();
                    Q.pop_back();
                    break;
                }
            }
        }
        // Re-attach each orphaned child q.
        typename std::vector<CoverTreeNode*>::const_iterator it;
        for(it=children.begin();it!=children.end();++it) {
            int i = level-1;
            Point q = (*it)->getPoint();
            double minDQ = DBL_MAX;
            CoverTreeNode* minDQNode;
            double sep_ = pow(base,i);
            bool br=false;
            // NOTE(review): this loop ends only when some cover set at level
            // >= level-1 holds a node within base^i of q. Cover sets above
            // _maxLevel are empty (created by operator[]), so termination
            // presumably relies on maxDist bounding all pairwise distances.
            while(true) {
                std::vector<distNodePair>&
                    Q_ = coverSets[i];
                typename std::vector<distNodePair>::const_iterator it2;
                minDQ = DBL_MAX;
                // Stops at the first node of Q_ within sep_ of q.
                for(it2=Q_.begin();it2!=Q_.end();++it2) {
                    double d = q.distance(it2->second->getPoint());
                    if(d<minDQ) {
                        minDQ = d;
                        minDQNode = it2->second;
                        if(d <=sep_) {
                            br=true;
                            break;
                        }
                    }
                }
                minDQ=DBL_MAX;
                if(br) break;
                // No parent at level i: make q itself a candidate at this
                // level and try one level higher.
                Q_.push_back(std::make_pair((*it)->distance(p),*it));
                i++;
                sep_ = pow(base,i);
            }
            //minDQNode->getPoint().print();
            //std::cout << " is level " << i << " parent of ";
            //(*it)->getPoint().print();
            minDQNode->addChild(i,*it);
        }
        if(parent!=NULL) {
            delete minNode;
            _numNodes--;
        }
    }
}
417 
418 template<class Point>
419 std::pair<double, typename CoverTree<Point>::CoverTreeNode*>
420 CoverTree<Point>::distance(const Point& p,
421  const std::vector<CoverTreeNode*>& Q)
422 {
423  double minDist = DBL_MAX;
424  CoverTreeNode* minNode;
425  typename std::vector<CoverTreeNode*>::const_iterator it;
426  for(it=Q.begin();it!=Q.end();++it) {
427  double dist = p.distance((*it)->getPoint());
428  if(dist < minDist) {
429  minDist = dist;
430  minNode = *it;
431  }
432  }
433  return std::make_pair(minDist,minNode);
434 }
435 
436 template<class Point>
437 void CoverTree<Point>::insert(const Point& newPoint)
438 {
439  if(_root==NULL) {
440  _root = new CoverTreeNode(newPoint);
441  _numNodes=1;
442  return;
443  }
444  //TODO: this is pretty inefficient, there may be a better way
445  //to check if the node already exists...
446  CoverTreeNode* n = kNearestNodes(newPoint,1)[0];
447  if(newPoint.distance(n->getPoint())==0.0) {
448  n->addPoint(newPoint);
449  } else {
450  //insert_rec acts under the assumption that there are no nodes with
451  //distance 0 to newPoint in the cover tree (the previous lines check it)
452  insert_rec(newPoint,
453  std::vector<distNodePair>
454  (1,std::make_pair(_root->distance(newPoint),_root)),
455  _maxLevel);
456  }
457 }
458 
/** Removes point p from the tree. This is a no-op for an empty tree, or when
 * no node stores a point equal to p.
 *
 * If p shares its node with other points, only p is dropped. When the root
 * holds only p:
 *  - if it is the only node, the tree becomes empty;
 *  - otherwise the root's last child at the highest non-empty level becomes
 *    the new root. That child is also placed in the top cover set so that
 *    orphans can be re-attached to it.
 * The remaining restructuring is done by remove_rec.
 */
template<class Point>
void CoverTree<Point>::remove(const Point& p)
{
    //Most of this function's code is for the special case of removing the root
    if(_root==NULL) return;
    bool removingRoot=_root->hasPoint(p);
    if(removingRoot && !_root->isSingle()) {
        _root->removePoint(p);
        return;
    }
    CoverTreeNode* newRoot=NULL;
    if(removingRoot) {
        if(_numNodes==1) {
            //removing the last node...
            delete _root;
            _numNodes--;
            _root=NULL;
            return;
        } else {
            // Promote the last child at the highest level that has children.
            for(int i=_maxLevel;i>_minLevel;i--) {
                if(!(_root->getChildren(i).empty())) {
                    newRoot = _root->getChildren(i).back();
                    _root->removeChild(i,newRoot);
                    break;
                }
            }
        }
    }
    // NOTE(review): if no child was found above, newRoot is still NULL and is
    // dereferenced below; presumably unreachable while _numNodes > 1 —
    // confirm.
    std::map<int, std::vector<distNodePair> > coverSets;
    coverSets[_maxLevel].push_back(std::make_pair(_root->distance(p),_root));
    if(removingRoot)
        coverSets[_maxLevel].push_back(std::make_pair(newRoot->distance(p),newRoot));
    bool multi = false;
    remove_rec(p,coverSets,_maxLevel,multi);
    // remove_rec does not delete the old root (it has no parent); do it here.
    if(removingRoot) {
        delete _root;
        _numNodes--;
        _root=newRoot;
    }
}
499 
500 template<class Point>
501 std::vector<Point> CoverTree<Point>::kNearestNeighbors(const Point& p,
502  const unsigned int& k) const
503 {
504  if(_root==NULL) return std::vector<Point>();
505  std::vector<CoverTreeNode*> v = kNearestNodes(p, k);
506  std::vector<Point> kNN;
507  typename std::vector<CoverTreeNode*>::const_iterator it;
508  for(it=v.begin();it!=v.end();++it) {
509  const std::vector<Point>& po = (*it)->getPoints();
510  kNN.insert(kNN.end(),po.begin(),po.end());
511  if(kNN.size() >= k) break;
512  }
513  return kNN;
514 }
515 
516 template<class Point>
518 {
519  return _root;
520 }
521 
522 template<class Point>
524  _points.push_back(p);
525 }
526 
527 template<class Point>
528 std::vector<typename CoverTree<Point>::CoverTreeNode*>
529 CoverTree<Point>::CoverTreeNode::getChildren(int level) const
530 {
531  typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator
532  it = _childMap.find(level);
533  if(it!=_childMap.end()) {
534  return it->second;
535  }
536  return std::vector<CoverTreeNode*>();
537 }
538 
539 template<class Point>
540 void CoverTree<Point>::CoverTreeNode::addChild(int level, CoverTreeNode* p)
541 {
542  _childMap[level].push_back(p);
543 }
544 
545 template<class Point>
546 void CoverTree<Point>::CoverTreeNode::removeChild(int level, CoverTreeNode* p)
547 {
548  std::vector<CoverTreeNode*>& v = _childMap[level];
549  for(unsigned int i=0;i<v.size();i++) {
550  if(v[i]==p) {
551  v[i]=v.back();
552  v.pop_back();
553  break;
554  }
555  }
556 }
557 
558 template<class Point>
559 void CoverTree<Point>::CoverTreeNode::addPoint(const Point& p)
560 {
561  if(find(_points.begin(), _points.end(), p) == _points.end())
562  _points.push_back(p);
563 }
564 
565 template<class Point>
566 void CoverTree<Point>::CoverTreeNode::removePoint(const Point& p)
567 {
568  typename std::vector<Point>::iterator it =
569  find(_points.begin(), _points.end(), p);
570  if(it != _points.end())
571  _points.erase(it);
572 }
573 
574 template<class Point>
575 double CoverTree<Point>::CoverTreeNode::distance(const CoverTreeNode& p) const
576 {
577  return _points[0].distance(p.getPoint());
578 }
579 
580 template<class Point>
581 bool CoverTree<Point>::CoverTreeNode::isSingle() const
582 {
583  return _points.size() == 1;
584 }
585 
586 template<class Point>
587 bool CoverTree<Point>::CoverTreeNode::hasPoint(const Point& p) const
588 {
589  return find(_points.begin(), _points.end(), p) != _points.end();
590 }
591 
592 template<class Point>
593 const Point& CoverTree<Point>::CoverTreeNode::getPoint() const { return _points[0]; }
594 
595 template<class Point>
596 std::vector<typename CoverTree<Point>::CoverTreeNode*>
597 CoverTree<Point>::CoverTreeNode::getAllChildren() const
598 {
599  std::vector<CoverTreeNode*> children;
600  typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator it;
601  for(it=_childMap.begin();it!=_childMap.end();++it) {
602  children.insert(children.end(), it->second.begin(), it->second.end());
603  }
604  return children;
605 }
606 
607 template<class Point>
609  if(_numNodes==0)
610  return _root==NULL;
611 
612  std::vector<CoverTreeNode*> nodes;
613  nodes.push_back(_root);
614  for(int i=_maxLevel;i>_minLevel;i--) {
615  double sep = pow(base,i);
616  typename std::vector<CoverTreeNode*>::const_iterator it, it2;
617  //verify separation invariant of cover tree: for each level,
618  //every point is farther than base^level away
619  for(it=nodes.begin(); it!=nodes.end(); ++it) {
620  for(it2=nodes.begin(); it2!=nodes.end(); ++it2) {
621  double dist=(*it)->distance((*it2)->getPoint());
622  if(dist<=sep && dist!=0.0) {
623  std::cout << "Level " << i << " Separation invariant failed.\n";
624  return false;
625  }
626  }
627  }
628  std::vector<CoverTreeNode*> allChildren;
629  for(it=nodes.begin(); it!=nodes.end(); ++it) {
630  std::vector<CoverTreeNode*> children = (*it)->getChildren(i);
631  //verify covering tree invariant: the children of node n at level
632  //i are no further than base^i away
633  for(it2=children.begin(); it2!=children.end(); ++it2) {
634  double dist = (*it2)->distance((*it)->getPoint());
635  if(dist>sep) {
636  std::cout << "Level" << i << " covering tree invariant failed.n";
637  return false;
638  }
639  }
640  allChildren.insert
641  (allChildren.end(),children.begin(),children.end());
642  }
643  nodes.insert(nodes.begin(),allChildren.begin(),allChildren.end());
644  }
645  return true;
646 }
647 }
648 #endif // _COVER_TREE_H
649 

SHOGUN Machine Learning Toolbox - Documentation