00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _COVER_TREE_H
00026 #define _COVER_TREE_H
00027
00028 #include <vector>
00029 #include <algorithm>
00030 #include <map>
00031 #include <set>
00032 #include <cmath>
00033 #include <float.h>
00034 #include <iostream>
00035
00036 namespace shogun
00037 {
00038
/**
 * Cover tree over objects of type Point.
 *
 * Point must provide `double distance(const Point&) const` and `operator==`
 * (the latter is used by std::find on a node's point list).
 *
 * Level i of the tree uses the radius base^i. The tree allocates its nodes
 * with new and deletes them in its destructor, so it is deliberately
 * non-copyable.
 */
template<class Point>
class CoverTree
{
    /**
     * Tree node: one or more points that are at distance 0 from each other
     * (the first one is the representative) plus children grouped by level.
     */
    class CoverTreeNode
    {
    private:
        /** children of this node, keyed by the level they were attached at */
        std::map<int,std::vector<CoverTreeNode*> > _childMap;

        /** points stored in this node; _points[0] is used for distances */
        std::vector<Point> _points;
    public:
        /** explicit: prevents silently building temporary nodes from points */
        explicit CoverTreeNode(const Point& p);

        /** @return a copy of the children attached at @p level (empty if none) */
        std::vector<CoverTreeNode*> getChildren(int level) const;
        void addChild(int level, CoverTreeNode* p);
        void removeChild(int level, CoverTreeNode* p);
        void addPoint(const Point& p);
        void removePoint(const Point& p);
        /** @return all points stored in this node */
        const std::vector<Point>& getPoints() const { return _points; }
        /** distance between the representative points of two nodes */
        double distance(const CoverTreeNode& p) const;
        /** distance from this node's representative point to @p p */
        double distance(const Point& p) const { return _points[0].distance(p); }

        bool isSingle() const;
        bool hasPoint(const Point& p) const;

        /** @return the representative point */
        const Point& getPoint() const;

        /** @return children of every level, in increasing level order */
        std::vector<CoverTreeNode*> getAllChildren() const;
    };
private:
    typedef std::pair<double, CoverTreeNode*> distNodePair;

    /** root node, NULL for an empty tree */
    CoverTreeNode* _root;
    /** number of allocated nodes */
    unsigned int _numNodes;
    /** top level of the tree */
    int _maxLevel;

    /** lowest level visited by searches */
    int _minLevel;

    std::vector<CoverTreeNode*>
    kNearestNodes(const Point& p, const unsigned int& k) const;
    bool insert_rec(const Point& p,
                    const std::vector<distNodePair>& Qi,
                    const int& level);

    distNodePair distance(const Point& p,
                          const std::vector<CoverTreeNode*>& Q);

    void remove_rec(const Point& p,
                    std::map<int,std::vector<distNodePair> >& coverSets,
                    int level,
                    bool& multi);

    // Declared but not defined: the tree owns raw node pointers, so a
    // memberwise copy would make two trees delete the same nodes.
    CoverTree(const CoverTree&);
    CoverTree& operator=(const CoverTree&);

public:
    /** base of the level radii (level i has radius base^i) */
    static const double base;

    CoverTree(const double& maxDist,
              const std::vector<Point>& points=std::vector<Point>());
    ~CoverTree();

    bool isValidTree() const;

    void insert(const Point& newPoint);

    void remove(const Point& p);

    std::vector<Point> kNearestNeighbors(const Point& p, const unsigned int& k) const;

    CoverTreeNode* getRoot() const;

};
00180
00181 template<class Point>
00182 const double CoverTree<Point>::base = 2.0;
00183
00184 template<class Point>
00185 CoverTree<Point>::CoverTree(const double& maxDist,
00186 const std::vector<Point>& points)
00187 {
00188 _root=NULL;
00189 _numNodes=0;
00190 _maxLevel=ceilf(log(maxDist)/log(base));
00191 _minLevel=_maxLevel-1;
00192 typename std::vector<Point>::const_iterator it;
00193 for(it=points.begin(); it!=points.end(); ++it) {
00194 this->insert(*it);
00195 }
00196 }
00197
00198 template<class Point>
00199 CoverTree<Point>::~CoverTree()
00200 {
00201 if(_root==NULL) return;
00202
00203
00204 std::vector<CoverTreeNode*> nodes;
00205 nodes.push_back(_root);
00206 while(!nodes.empty()) {
00207 CoverTreeNode* byeNode = nodes[0];
00208 nodes.erase(nodes.begin());
00209 std::vector<CoverTreeNode*> children = byeNode->getAllChildren();
00210 nodes.insert(nodes.begin(),children.begin(),children.end());
00211
00212 delete byeNode;
00213
00214 }
00215 }
00216
00217 template<class Point>
00218 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00219 CoverTree<Point>::kNearestNodes(const Point& p, const unsigned int& k) const
00220 {
00221 if(_root==NULL) return std::vector<CoverTreeNode*>();
00222
00223
00224 double maxDist = p.distance(_root->getPoint());
00225
00226 std::set<distNodePair> minNodes;
00227
00228 minNodes.insert(std::make_pair(maxDist,_root));
00229 std::vector<distNodePair> Qj(1,std::make_pair(maxDist,_root));
00230 for(int level = _maxLevel; level>=_minLevel;level--) {
00231 typename std::vector<distNodePair>::const_iterator it;
00232 int size = Qj.size();
00233 for(int i=0; i<size; i++) {
00234 std::vector<CoverTreeNode*> children =
00235 Qj[i].second->getChildren(level);
00236 typename std::vector<CoverTreeNode*>::const_iterator it2;
00237 for(it2=children.begin(); it2!=children.end(); ++it2) {
00238 double d = p.distance((*it2)->getPoint());
00239 if(d < maxDist || minNodes.size() < k) {
00240 minNodes.insert(std::make_pair(d,*it2));
00241
00242
00243 if(minNodes.size() > k) minNodes.erase(--minNodes.end());
00244 maxDist = (--minNodes.end())->first;
00245 }
00246 Qj.push_back(std::make_pair(d,*it2));
00247 }
00248 }
00249 double sep = maxDist + pow(base, level);
00250 size = Qj.size();
00251 for(int i=0; i<size; i++) {
00252 if(Qj[i].first > sep) {
00253
00254 Qj[i]=Qj.back();
00255 Qj.pop_back();
00256 size--; i--;
00257 }
00258 }
00259 }
00260 std::vector<CoverTreeNode*> kNN;
00261 typename std::set<distNodePair>::const_iterator it;
00262 for(it=minNodes.begin();it!=minNodes.end();++it) {
00263 kNN.push_back(it->second);
00264 }
00265 return kNN;
00266 }
00267 template<class Point>
00268 bool CoverTree<Point>::insert_rec(const Point& p,
00269 const std::vector<distNodePair>& Qi,
00270 const int& level)
00271 {
00272 std::vector<std::pair<double, CoverTreeNode*> > Qj;
00273 double sep = pow(base,level);
00274 double minDist = DBL_MAX;
00275 std::pair<double,CoverTreeNode*> minQiDist(DBL_MAX,NULL);
00276 typename std::vector<std::pair<double, CoverTreeNode*> >::const_iterator it;
00277 for(it=Qi.begin(); it!=Qi.end(); ++it) {
00278 if(it->first<minQiDist.first) minQiDist = *it;
00279 if(it->first<minDist) minDist=it->first;
00280 if(it->first<=sep) Qj.push_back(*it);
00281 std::vector<CoverTreeNode*> children = it->second->getChildren(level);
00282 typename std::vector<CoverTreeNode*>::const_iterator it2;
00283 for(it2=children.begin();it2!=children.end();++it2) {
00284 double d = p.distance((*it2)->getPoint());
00285 if(d<minDist) minDist = d;
00286 if(d<=sep) {
00287 Qj.push_back(std::make_pair(d,*it2));
00288 }
00289 }
00290 }
00291
00292 if(minDist > sep) {
00293 return true;
00294 } else {
00295 bool found = insert_rec(p,Qj,level-1);
00296
00297 if(found && minQiDist.first <= sep) {
00298 if(level-1<_minLevel) _minLevel=level-1;
00299 minQiDist.second->addChild(level,
00300 new CoverTreeNode(p));
00301
00302
00303 _numNodes++;
00304 return false;
00305 } else {
00306 return found;
00307 }
00308 }
00309 }
00310
/**
 * One level of the recursive removal of point p, called for every level from
 * the one remove() passes in down to _minLevel.
 *
 * coverSets[level] holds the (distance-to-p, node) candidates of this level.
 * Before recursing, this call fills coverSets[level-1] with every candidate
 * and every child at this level lying within base^level of p. After the
 * recursion returns (so this cleanup runs bottom-up), the node holding p is
 * detached and its children at level-1 are re-attached elsewhere.
 *
 * @param p         point to remove
 * @param coverSets per-level candidate sets shared by all recursion levels
 * @param level     current level (radius base^level)
 * @param multi     set to true once p was removed from a node holding several
 *                  points; outer calls that still see p then return at once
 */
00311 template<class Point>
00312 void CoverTree<Point>::remove_rec(const Point& p,
00313 std::map<int,std::vector<distNodePair> >& coverSets,
00314 int level,
00315 bool& multi)
00316 {
// Qi: candidates of this level. Qj: candidate set for the next level down,
// filled by the loop below. References into a std::map stay valid when other
// keys are inserted later.
00317 std::vector<distNodePair>& Qi = coverSets[level];
00318 std::vector<distNodePair>& Qj = coverSets[level-1];
00319 double minDist = DBL_MAX;
// Node closest to p at this level; stays _root if Qi is empty.
00320 CoverTreeNode* minNode = _root;
// Parent of minNode; only set when minNode was found as a child at this
// level with distance exactly 0.
00321 CoverTreeNode* parent = 0;
00322 double sep = pow(base, level);
00323 typename std::vector<distNodePair>::const_iterator it_;
00324
00325
00326
00327
00328 for(it_=Qi.begin();it_!=Qi.end();++it_) {
00329 std::vector<CoverTreeNode*> children = it_->second->getChildren(level);
00330 double dist = it_->first;
00331 if(dist<minDist) {
00332 minDist = dist;
00333 minNode = it_->second;
00334 }
00335 if(dist <= sep) {
00336 Qj.push_back(*it_);
00337 }
00338 typename std::vector<CoverTreeNode*>::const_iterator it2;
00339 for(it2=children.begin();it2!=children.end();++it2) {
00340 dist = p.distance((*it2)->getPoint());
00341 if(dist<minDist) {
00342 minDist = dist;
00343 minNode = *it2;
00344 if(dist == 0.0) parent = it_->second;
00345 }
00346 if(dist <= sep) {
00347 Qj.push_back(std::make_pair(dist,*it2));
00348 }
00349 }
00350 }
// Recurse first, so the cleanup below runs from the lowest level upwards.
00351 if(level>_minLevel) remove_rec(p,coverSets,level-1,multi);
00352 if(minNode->hasPoint(p)) {
00353
00354
00355
// A deeper call already removed one copy of p from a multi-point node.
00356 if(multi) return;
// Several identical points share this node: drop one copy, keep the node.
00357 if(!minNode->isSingle()) {
00358 minNode->removePoint(p);
00359 multi=true;
00360 return;
00361 }
// Detach the node from its parent (only known at the level where it was attached).
00362 if(parent!=NULL) parent->removeChild(level, minNode);
00363 std::vector<CoverTreeNode*> children = minNode->getChildren(level-1);
// Remove minNode from the next level's candidate set (swap-with-last), so its
// children are not searched against it at level-1.
00364 std::vector<distNodePair>& Q = coverSets[level-1];
00365 if(Q.size()==1 && Q[0].second==minNode) {
00366 Q.pop_back();
00367 } else {
00368 for(unsigned int i=0;i<Q.size();i++) {
00369 if(Q[i].second==minNode) {
00370 Q[i]=Q.back();
00371 Q.pop_back();
00372 break;
00373 }
00374 }
00375 }
// Re-attach every child of minNode at level-1. Starting at i = level-1, look in
// coverSets[i] for a node within base^i of the child. If none is found, add the
// child itself to coverSets[i] and retry one level higher.
// NOTE(review): coverSets[level] is not purged of minNode, so the climb can
// re-attach a child to minNode itself at i == level. For a node reached through
// Qi below its own level, the caller's pass re-reads minNode's children at that
// level and moves them again. When remove() deletes the root, coverSets[_maxLevel]
// still starts with the old root, so such children may end up under a node that
// remove() then frees — verify.
00376 typename std::vector<CoverTreeNode*>::const_iterator it;
00377 for(it=children.begin();it!=children.end();++it) {
00378 int i = level-1;
00379 Point q = (*it)->getPoint();
// NOTE(review): minDQ is only a running minimum inside the scan. The scan stops
// at the first node within sep_, so in effect the new parent is the first node
// of Q_ within sep_, not necessarily the closest one.
00380 double minDQ = DBL_MAX;
// NOTE(review): left uninitialised. It is assigned before br is set, so it is
// valid whenever the while(true) exits. The loop only exits once some
// coverSets[i] holds a node within base^i of the child.
00381 CoverTreeNode* minDQNode;
00382 double sep_ = pow(base,i);
00383 bool br=false;
00384 while(true) {
00385 std::vector<distNodePair>&
00386 Q_ = coverSets[i];
00387 typename std::vector<distNodePair>::const_iterator it2;
00388 minDQ = DBL_MAX;
00389 for(it2=Q_.begin();it2!=Q_.end();++it2) {
00390 double d = q.distance(it2->second->getPoint());
00391 if(d<minDQ) {
00392 minDQ = d;
00393 minDQNode = it2->second;
00394 if(d <=sep_) {
00395 br=true;
00396 break;
00397 }
00398 }
00399 }
00400 minDQ=DBL_MAX;
00401 if(br) break;
// No parent at level i: the child becomes a candidate at level i itself
// and the search moves one level up.
00402 Q_.push_back(std::make_pair((*it)->distance(p),*it));
00403 i++;
00404 sep_ = pow(base,i);
00405 }
00406
00407
00408
00409 minDQNode->addChild(i,*it);
00410 }
// minNode is freed only in the call that found it as a child of parent. When it
// was reached through Qi (e.g. the root), it is not deleted here; remove()
// deletes the old root itself.
// NOTE(review): children of the root attached at _maxLevel itself are never
// re-attached here (each call only re-homes children at level-1) — verify
// root removal with several top-level children.
00411 if(parent!=NULL) {
00412 delete minNode;
00413 _numNodes--;
00414 }
00415 }
00416 }
00417
00418 template<class Point>
00419 std::pair<double, typename CoverTree<Point>::CoverTreeNode*>
00420 CoverTree<Point>::distance(const Point& p,
00421 const std::vector<CoverTreeNode*>& Q)
00422 {
00423 double minDist = DBL_MAX;
00424 CoverTreeNode* minNode;
00425 typename std::vector<CoverTreeNode*>::const_iterator it;
00426 for(it=Q.begin();it!=Q.end();++it) {
00427 double dist = p.distance((*it)->getPoint());
00428 if(dist < minDist) {
00429 minDist = dist;
00430 minNode = *it;
00431 }
00432 }
00433 return std::make_pair(minDist,minNode);
00434 }
00435
00436 template<class Point>
00437 void CoverTree<Point>::insert(const Point& newPoint)
00438 {
00439 if(_root==NULL) {
00440 _root = new CoverTreeNode(newPoint);
00441 _numNodes=1;
00442 return;
00443 }
00444
00445
00446 CoverTreeNode* n = kNearestNodes(newPoint,1)[0];
00447 if(newPoint.distance(n->getPoint())==0.0) {
00448 n->addPoint(newPoint);
00449 } else {
00450
00451
00452 insert_rec(newPoint,
00453 std::vector<distNodePair>
00454 (1,std::make_pair(_root->distance(newPoint),_root)),
00455 _maxLevel);
00456 }
00457 }
00458
00459 template<class Point>
00460 void CoverTree<Point>::remove(const Point& p)
00461 {
00462
00463 if(_root==NULL) return;
00464 bool removingRoot=_root->hasPoint(p);
00465 if(removingRoot && !_root->isSingle()) {
00466 _root->removePoint(p);
00467 return;
00468 }
00469 CoverTreeNode* newRoot=NULL;
00470 if(removingRoot) {
00471 if(_numNodes==1) {
00472
00473 delete _root;
00474 _numNodes--;
00475 _root=NULL;
00476 return;
00477 } else {
00478 for(int i=_maxLevel;i>_minLevel;i--) {
00479 if(!(_root->getChildren(i).empty())) {
00480 newRoot = _root->getChildren(i).back();
00481 _root->removeChild(i,newRoot);
00482 break;
00483 }
00484 }
00485 }
00486 }
00487 std::map<int, std::vector<distNodePair> > coverSets;
00488 coverSets[_maxLevel].push_back(std::make_pair(_root->distance(p),_root));
00489 if(removingRoot)
00490 coverSets[_maxLevel].push_back(std::make_pair(newRoot->distance(p),newRoot));
00491 bool multi = false;
00492 remove_rec(p,coverSets,_maxLevel,multi);
00493 if(removingRoot) {
00494 delete _root;
00495 _numNodes--;
00496 _root=newRoot;
00497 }
00498 }
00499
00500 template<class Point>
00501 std::vector<Point> CoverTree<Point>::kNearestNeighbors(const Point& p,
00502 const unsigned int& k) const
00503 {
00504 if(_root==NULL) return std::vector<Point>();
00505 std::vector<CoverTreeNode*> v = kNearestNodes(p, k);
00506 std::vector<Point> kNN;
00507 typename std::vector<CoverTreeNode*>::const_iterator it;
00508 for(it=v.begin();it!=v.end();++it) {
00509 const std::vector<Point>& po = (*it)->getPoints();
00510 kNN.insert(kNN.end(),po.begin(),po.end());
00511 if(kNN.size() >= k) break;
00512 }
00513 return kNN;
00514 }
00515
00516 template<class Point>
00517 typename CoverTree<Point>::CoverTreeNode* CoverTree<Point>::getRoot() const
00518 {
00519 return _root;
00520 }
00521
00522 template<class Point>
00523 CoverTree<Point>::CoverTreeNode::CoverTreeNode(const Point& p) {
00524 _points.push_back(p);
00525 }
00526
00527 template<class Point>
00528 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00529 CoverTree<Point>::CoverTreeNode::getChildren(int level) const
00530 {
00531 typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator
00532 it = _childMap.find(level);
00533 if(it!=_childMap.end()) {
00534 return it->second;
00535 }
00536 return std::vector<CoverTreeNode*>();
00537 }
00538
00539 template<class Point>
00540 void CoverTree<Point>::CoverTreeNode::addChild(int level, CoverTreeNode* p)
00541 {
00542 _childMap[level].push_back(p);
00543 }
00544
00545 template<class Point>
00546 void CoverTree<Point>::CoverTreeNode::removeChild(int level, CoverTreeNode* p)
00547 {
00548 std::vector<CoverTreeNode*>& v = _childMap[level];
00549 for(unsigned int i=0;i<v.size();i++) {
00550 if(v[i]==p) {
00551 v[i]=v.back();
00552 v.pop_back();
00553 break;
00554 }
00555 }
00556 }
00557
00558 template<class Point>
00559 void CoverTree<Point>::CoverTreeNode::addPoint(const Point& p)
00560 {
00561 if(find(_points.begin(), _points.end(), p) == _points.end())
00562 _points.push_back(p);
00563 }
00564
00565 template<class Point>
00566 void CoverTree<Point>::CoverTreeNode::removePoint(const Point& p)
00567 {
00568 typename std::vector<Point>::iterator it =
00569 find(_points.begin(), _points.end(), p);
00570 if(it != _points.end())
00571 _points.erase(it);
00572 }
00573
00574 template<class Point>
00575 double CoverTree<Point>::CoverTreeNode::distance(const CoverTreeNode& p) const
00576 {
00577 return _points[0].distance(p.getPoint());
00578 }
00579
00580 template<class Point>
00581 bool CoverTree<Point>::CoverTreeNode::isSingle() const
00582 {
00583 return _points.size() == 1;
00584 }
00585
00586 template<class Point>
00587 bool CoverTree<Point>::CoverTreeNode::hasPoint(const Point& p) const
00588 {
00589 return find(_points.begin(), _points.end(), p) != _points.end();
00590 }
00591
00592 template<class Point>
00593 const Point& CoverTree<Point>::CoverTreeNode::getPoint() const { return _points[0]; }
00594
00595 template<class Point>
00596 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00597 CoverTree<Point>::CoverTreeNode::getAllChildren() const
00598 {
00599 std::vector<CoverTreeNode*> children;
00600 typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator it;
00601 for(it=_childMap.begin();it!=_childMap.end();++it) {
00602 children.insert(children.end(), it->second.begin(), it->second.end());
00603 }
00604 return children;
00605 }
00606
00607 template<class Point>
00608 bool CoverTree<Point>::isValidTree() const {
00609 if(_numNodes==0)
00610 return _root==NULL;
00611
00612 std::vector<CoverTreeNode*> nodes;
00613 nodes.push_back(_root);
00614 for(int i=_maxLevel;i>_minLevel;i--) {
00615 double sep = pow(base,i);
00616 typename std::vector<CoverTreeNode*>::const_iterator it, it2;
00617
00618
00619 for(it=nodes.begin(); it!=nodes.end(); ++it) {
00620 for(it2=nodes.begin(); it2!=nodes.end(); ++it2) {
00621 double dist=(*it)->distance((*it2)->getPoint());
00622 if(dist<=sep && dist!=0.0) {
00623 std::cout << "Level " << i << " Separation invariant failed.\n";
00624 return false;
00625 }
00626 }
00627 }
00628 std::vector<CoverTreeNode*> allChildren;
00629 for(it=nodes.begin(); it!=nodes.end(); ++it) {
00630 std::vector<CoverTreeNode*> children = (*it)->getChildren(i);
00631
00632
00633 for(it2=children.begin(); it2!=children.end(); ++it2) {
00634 double dist = (*it2)->distance((*it)->getPoint());
00635 if(dist>sep) {
00636 std::cout << "Level" << i << " covering tree invariant failed.n";
00637 return false;
00638 }
00639 }
00640 allChildren.insert
00641 (allChildren.end(),children.begin(),children.end());
00642 }
00643 nodes.insert(nodes.begin(),allChildren.begin(),allChildren.end());
00644 }
00645 return true;
00646 }
00647 }
00648 #endif // _COVER_TREE_H
00649