CoverTree.h

Go to the documentation of this file.
00001 /*
00002  * Copyright (C) 2011 by Singularity Institute for Artificial Intelligence
00003  * All Rights Reserved
00004  *
00005  * Written by David Crane <dncrane@gmail.com>
00006  *
00007  * This program is free software; you can redistribute it and/or modify
00008  * it under the terms of the GNU Affero General Public License v3 as
00009  * published by the Free Software Foundation and including the exceptions
00010  * at http://opencog.org/wiki/Licenses
00011  *
00012  * This program is distributed in the hope that it will be useful,
00013  * but WITHOUT ANY WARRANTY; without even the implied warranty of
00014  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00015  * GNU General Public License for more details.
00016  *
00017  * You should have received a copy of the GNU Affero General Public License
00018  * along with this program; if not, write to:
00019  * Free Software Foundation, Inc.,
00020  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
00021  *
00022  * Shogun modifications by Sergey Lisitsyn
00023  */
00024 
00025 #ifndef _COVER_TREE_H
00026 #define _COVER_TREE_H
00027 
00028 #include <vector>
00029 #include <algorithm>
00030 #include <map>
00031 #include <set>
00032 #include <cmath>
00033 #include <float.h>
00034 #include <iostream>
00035 
00036 namespace shogun
00037 {
00038 
/**
 * Cover tree over points of type Point, supporting insertion, removal and
 * nearest-neighbour queries.
 *
 * Requirements on Point, as used by the member definitions:
 *  - double distance(const Point&) const   -- the metric
 *  - operator==                            -- identifies a point inside a node
 *
 * Points at distance 0.0 from one another share a single node. The tree
 * owns its nodes (allocated with new, released by the destructor), so
 * copying is disabled.
 */
template<class Point>
class CoverTree
{
    /**
     * One tree node: a representative point, any further points at
     * distance 0 from it, and its children grouped by level.
     */
    class CoverTreeNode
    {
    private:
        //_childMap[i] is a vector of the node's children at level i
        std::map<int,std::vector<CoverTreeNode*> > _childMap;
        //_points is all of the points with distance 0 which are not equal.
        std::vector<Point> _points;
    public:
        // Deliberately not explicit: calls such as node->distance(point)
        // rely on a Point converting to a temporary node.
        CoverTreeNode(const Point& p);
        /** Copy of the children stored at @p level (empty if none). */
        std::vector<CoverTreeNode*> getChildren(int level) const;
        void addChild(int level, CoverTreeNode* p);
        void removeChild(int level, CoverTreeNode* p);
        /** Adds @p p unless an equal point is already stored. */
        void addPoint(const Point& p);
        void removePoint(const Point& p);
        /** All points held by this node (read-only). */
        const std::vector<Point>& getPoints() const { return _points; }
        /** Distance between the representative points of two nodes. */
        double distance(const CoverTreeNode& p) const;

        /** True when the node holds exactly one point. */
        bool isSingle() const;
        bool hasPoint(const Point& p) const;

        /** The node's representative (first) point. */
        const Point& getPoint() const;

        /** Children from every level, concatenated in ascending level order. */
        std::vector<CoverTreeNode*> getAllChildren() const;
    }; // CoverTreeNode class
 private:
    typedef std::pair<double, CoverTreeNode*> distNodePair;

    CoverTreeNode* _root;
    unsigned int _numNodes;
    int _maxLevel;//base^_maxLevel should be the max distance
                  //between any 2 points
    int _minLevel;//A level beneath which there are no more new nodes.

    // Copying would duplicate the raw node pointers, and the destructor of
    // each copy would then delete the same nodes. Declared but intentionally
    // not defined, so any attempt to copy fails to compile.
    CoverTree(const CoverTree&);
    CoverTree& operator=(const CoverTree&);

    /** Nodes nearest to @p p (k requested), ordered by increasing distance. */
    std::vector<CoverTreeNode*>
        kNearestNodes(const Point& p, const unsigned int& k) const;
    /** Recursive insertion step; returns true while @p p still needs a parent. */
    bool insert_rec(const Point& p,
                    const std::vector<distNodePair>& Qi,
                    const int& level);

    /** Closest node of @p Q to @p p, together with that distance. */
    distNodePair distance(const Point& p,
                          const std::vector<CoverTreeNode*>& Q);

    /** Recursive removal step over the per-level cover sets. */
    void remove_rec(const Point& p,
                    std::map<int,std::vector<distNodePair> >& coverSets,
                    int level,
                    bool& multi);

 public:
    /** Expansion constant: level i uses the radius base^i. */
    static const double base;

    /**
     * @param maxDist upper bound on the distance between any two points
     * @param points  points inserted at construction time
     */
    CoverTree(const double& maxDist,
              const std::vector<Point>& points=std::vector<Point>());
    ~CoverTree();

    /** Checks the separation and covering invariants; prints the first failure. */
    bool isValidTree() const;

    /** Inserts @p newPoint (merged into an existing node at distance 0). */
    void insert(const Point& newPoint);

    /** Removes @p p from the tree. */
    void remove(const Point& p);

    /** Nearest points to @p p, gathered node by node in increasing distance. */
    std::vector<Point> kNearestNeighbors(const Point& p, const unsigned int& k) const;

    /** The root node (NULL when the tree is empty); not owned by the caller. */
    CoverTreeNode* getRoot() const;

}; // CoverTree class
00180 
00181 template<class Point>
00182 const double CoverTree<Point>::base = 2.0;
00183 
00184 template<class Point>
00185 CoverTree<Point>::CoverTree(const double& maxDist,
00186                             const std::vector<Point>& points)
00187 {
00188     _root=NULL;
00189     _numNodes=0;
00190     _maxLevel=ceilf(log(maxDist)/log(base));
00191     _minLevel=_maxLevel-1;
00192     typename std::vector<Point>::const_iterator it;
00193     for(it=points.begin(); it!=points.end(); ++it) {
00194         this->insert(*it);
00195     }
00196 }
00197 
00198 template<class Point>
00199 CoverTree<Point>::~CoverTree()
00200 {
00201     if(_root==NULL) return;
00202     //Get all of the root's children (from any level),
00203     //delete the root, repeat for each of the children
00204     std::vector<CoverTreeNode*> nodes;
00205     nodes.push_back(_root);
00206     while(!nodes.empty()) {
00207         CoverTreeNode* byeNode = nodes[0];
00208         nodes.erase(nodes.begin());
00209         std::vector<CoverTreeNode*> children = byeNode->getAllChildren();
00210         nodes.insert(nodes.begin(),children.begin(),children.end());
00211         //std::cout << _numNodes << "\n";
00212         delete byeNode;
00213         //_numNodes--;
00214     }
00215 }
00216 
00217 template<class Point>
00218 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00219 CoverTree<Point>::kNearestNodes(const Point& p, const unsigned int& k) const
00220 {
00221     if(_root==NULL) return std::vector<CoverTreeNode*>();
00222     //maxDist is the kth nearest known point to p, and also the farthest
00223     //point from p in the set minNodes defined below.
00224     double maxDist = p.distance(_root->getPoint());
00225     //minNodes stores the k nearest known points to p.
00226     std::set<distNodePair> minNodes;
00227 
00228     minNodes.insert(std::make_pair(maxDist,_root));
00229     std::vector<distNodePair> Qj(1,std::make_pair(maxDist,_root));
00230     for(int level = _maxLevel; level>=_minLevel;level--) {
00231         typename std::vector<distNodePair>::const_iterator it;
00232         int size = Qj.size();
00233         for(int i=0; i<size; i++) {
00234             std::vector<CoverTreeNode*> children =
00235                 Qj[i].second->getChildren(level);
00236             typename std::vector<CoverTreeNode*>::const_iterator it2;
00237             for(it2=children.begin(); it2!=children.end(); ++it2) {
00238                 double d = p.distance((*it2)->getPoint());
00239                 if(d < maxDist || minNodes.size() < k) {
00240                     minNodes.insert(std::make_pair(d,*it2));
00241                     //--minNodes.end() gives us an iterator to the greatest
00242                     //element of minNodes.
00243                     if(minNodes.size() > k) minNodes.erase(--minNodes.end());
00244                     maxDist = (--minNodes.end())->first;
00245                 }
00246                 Qj.push_back(std::make_pair(d,*it2));
00247             }
00248         }
00249         double sep = maxDist + pow(base, level);
00250         size = Qj.size();
00251         for(int i=0; i<size; i++) {
00252             if(Qj[i].first > sep) {
00253                 //quickly removes an element from a vector w/o preserving order.
00254                 Qj[i]=Qj.back();
00255                 Qj.pop_back();
00256                 size--; i--;
00257             }
00258         }
00259     }
00260     std::vector<CoverTreeNode*> kNN;
00261     typename std::set<distNodePair>::const_iterator it;
00262     for(it=minNodes.begin();it!=minNodes.end();++it) {
00263         kNN.push_back(it->second);
00264     }
00265     return kNN;
00266 }
00267 template<class Point>
00268 bool CoverTree<Point>::insert_rec(const Point& p,
00269                                   const std::vector<distNodePair>& Qi,
00270                                   const int& level)
00271 {
00272     std::vector<std::pair<double, CoverTreeNode*> > Qj;
00273     double sep = pow(base,level);
00274     double minDist = DBL_MAX;
00275     std::pair<double,CoverTreeNode*> minQiDist(DBL_MAX,NULL);
00276     typename  std::vector<std::pair<double, CoverTreeNode*> >::const_iterator it;
00277     for(it=Qi.begin(); it!=Qi.end(); ++it) {
00278         if(it->first<minQiDist.first) minQiDist = *it;
00279         if(it->first<minDist) minDist=it->first;
00280         if(it->first<=sep) Qj.push_back(*it);
00281         std::vector<CoverTreeNode*> children = it->second->getChildren(level);
00282         typename std::vector<CoverTreeNode*>::const_iterator it2;
00283         for(it2=children.begin();it2!=children.end();++it2) {
00284             double d = p.distance((*it2)->getPoint());
00285             if(d<minDist) minDist = d;
00286             if(d<=sep) {
00287                 Qj.push_back(std::make_pair(d,*it2));
00288             }
00289         }
00290     }
00291     //std::cout << "level: " << level << ", sep: " << sep << ", dist: " << minQDist.first << "\n";
00292     if(minDist > sep) {
00293         return true;
00294     } else {
00295         bool found = insert_rec(p,Qj,level-1);
00296         //distNodePair minQiDist = distance(p,Qi);
00297         if(found && minQiDist.first <= sep) {
00298             if(level-1<_minLevel) _minLevel=level-1;
00299             minQiDist.second->addChild(level,
00300                                        new CoverTreeNode(p));
00301             //std::cout << "parent is ";
00302             //minQiDist.second->getPoint().print();
00303             _numNodes++;
00304             return false;
00305         } else {
00306             return found;
00307         }
00308     }
00309 }
00310 
00311 template<class Point>
00312 void CoverTree<Point>::remove_rec(const Point& p,
00313                                   std::map<int,std::vector<distNodePair> >& coverSets,
00314                                   int level,
00315                                   bool& multi)
00316 {
00317     std::vector<distNodePair>& Qi = coverSets[level];
00318     std::vector<distNodePair>& Qj = coverSets[level-1];
00319     double minDist = DBL_MAX;
00320     CoverTreeNode* minNode = _root;
00321     CoverTreeNode* parent = 0;
00322     double sep = pow(base, level);
00323     typename std::vector<distNodePair>::const_iterator it_;
00324     //set Qj to be all children q of Qi such that p.distance(q)<=sep
00325     //and also keep track of the minimum distance from p to a node in Qj
00326     //note that every node has itself as a child, but the
00327     //getChildren function only returns non-self-children.
00328     for(it_=Qi.begin();it_!=Qi.end();++it_) {
00329         std::vector<CoverTreeNode*> children = it_->second->getChildren(level);
00330         double dist = it_->first;
00331         if(dist<minDist) {
00332             minDist = dist;
00333             minNode = it_->second;
00334         }
00335         if(dist <= sep) {
00336             Qj.push_back(*it_);
00337         }
00338         typename std::vector<CoverTreeNode*>::const_iterator it2;
00339         for(it2=children.begin();it2!=children.end();++it2) {
00340             dist = p.distance((*it2)->getPoint());
00341             if(dist<minDist) {
00342                 minDist = dist;
00343                 minNode = *it2;
00344                 if(dist == 0.0) parent = it_->second;
00345             }
00346             if(dist <= sep) {
00347                 Qj.push_back(std::make_pair(dist,*it2));
00348             }
00349         }
00350     }
00351     if(level>_minLevel) remove_rec(p,coverSets,level-1,multi);
00352     if(minNode->hasPoint(p)) {
00353         //the multi flag indicates the point we removed is from a
00354         //node containing multiple points, and we have removed it,
00355         //so we don't need to do anything else.
00356         if(multi) return;
00357         if(!minNode->isSingle()) {
00358             minNode->removePoint(p);
00359             multi=true;
00360             return;
00361         }
00362         if(parent!=NULL) parent->removeChild(level, minNode);
00363         std::vector<CoverTreeNode*> children = minNode->getChildren(level-1);
00364         std::vector<distNodePair>& Q = coverSets[level-1];
00365         if(Q.size()==1 && Q[0].second==minNode) {
00366             Q.pop_back();
00367         } else {
00368             for(unsigned int i=0;i<Q.size();i++) {
00369                 if(Q[i].second==minNode) {
00370                     Q[i]=Q.back();
00371                     Q.pop_back();
00372                     break;
00373                 }
00374             }
00375         }
00376         typename std::vector<CoverTreeNode*>::const_iterator it;
00377         for(it=children.begin();it!=children.end();++it) {
00378             int i = level-1;
00379             Point q = (*it)->getPoint();
00380             double minDQ = DBL_MAX;
00381             CoverTreeNode* minDQNode;
00382             double sep_ = pow(base,i);
00383             bool br=false;
00384             while(true) {
00385                 std::vector<distNodePair>&
00386                     Q_ = coverSets[i];
00387                 typename std::vector<distNodePair>::const_iterator it2;
00388                 minDQ = DBL_MAX;
00389                 for(it2=Q_.begin();it2!=Q_.end();++it2) {
00390                     double d = q.distance(it2->second->getPoint());
00391                     if(d<minDQ) {
00392                         minDQ = d;
00393                         minDQNode = it2->second;
00394                         if(d <=sep_) {
00395                             br=true;
00396                             break;
00397                         }
00398                     }
00399                 }
00400                 minDQ=DBL_MAX;
00401                 if(br) break;
00402                 Q_.push_back(std::make_pair((*it)->distance(p),*it));
00403                 i++;
00404                 sep_ = pow(base,i);
00405             }
00406             //minDQNode->getPoint().print();
00407             //std::cout << " is level " << i << " parent of ";
00408             //(*it)->getPoint().print();
00409             minDQNode->addChild(i,*it);
00410         }
00411         if(parent!=NULL) {
00412             delete minNode;
00413             _numNodes--;
00414         }
00415     }
00416 }
00417 
00418 template<class Point>
00419 std::pair<double, typename CoverTree<Point>::CoverTreeNode*>
00420 CoverTree<Point>::distance(const Point& p,
00421                            const std::vector<CoverTreeNode*>& Q)
00422 {
00423     double minDist = DBL_MAX;
00424     CoverTreeNode* minNode;
00425     typename std::vector<CoverTreeNode*>::const_iterator it;
00426     for(it=Q.begin();it!=Q.end();++it) {
00427         double dist = p.distance((*it)->getPoint());
00428         if(dist < minDist) {
00429             minDist = dist;
00430             minNode = *it;
00431         }
00432     }
00433     return std::make_pair(minDist,minNode);
00434 }
00435 
00436 template<class Point>
00437 void CoverTree<Point>::insert(const Point& newPoint)
00438 {
00439     if(_root==NULL) {
00440         _root = new CoverTreeNode(newPoint);
00441         _numNodes=1;
00442         return;
00443     }
00444     //TODO: this is pretty inefficient, there may be a better way
00445     //to check if the node already exists...
00446     CoverTreeNode* n = kNearestNodes(newPoint,1)[0];
00447     if(newPoint.distance(n->getPoint())==0.0) {
00448         n->addPoint(newPoint);
00449     } else {
00450         //insert_rec acts under the assumption that there are no nodes with
00451         //distance 0 to newPoint in the cover tree (the previous lines check it)
00452         insert_rec(newPoint,
00453                    std::vector<distNodePair>
00454                    (1,std::make_pair(_root->distance(newPoint),_root)),
00455                    _maxLevel);
00456     }
00457 }
00458 
00459 template<class Point>
00460 void CoverTree<Point>::remove(const Point& p)
00461 {
00462     //Most of this function's code is for the special case of removing the root
00463     if(_root==NULL) return;
00464     bool removingRoot=_root->hasPoint(p);
00465     if(removingRoot && !_root->isSingle()) {
00466         _root->removePoint(p);
00467         return;
00468     }
00469     CoverTreeNode* newRoot=NULL;
00470     if(removingRoot) {
00471         if(_numNodes==1) {
00472             //removing the last node...
00473             delete _root;
00474             _numNodes--;
00475             _root=NULL;
00476             return;
00477         } else {
00478             for(int i=_maxLevel;i>_minLevel;i--) {
00479                 if(!(_root->getChildren(i).empty())) {
00480                     newRoot = _root->getChildren(i).back();
00481                     _root->removeChild(i,newRoot);
00482                     break;
00483                 }
00484             }
00485         }
00486     }
00487     std::map<int, std::vector<distNodePair> > coverSets;
00488     coverSets[_maxLevel].push_back(std::make_pair(_root->distance(p),_root));
00489     if(removingRoot)
00490         coverSets[_maxLevel].push_back(std::make_pair(newRoot->distance(p),newRoot));
00491     bool multi = false;
00492     remove_rec(p,coverSets,_maxLevel,multi);
00493     if(removingRoot) {
00494         delete _root;
00495         _numNodes--;
00496         _root=newRoot;
00497     }
00498 }
00499 
00500 template<class Point>
00501 std::vector<Point> CoverTree<Point>::kNearestNeighbors(const Point& p,
00502                                                        const unsigned int& k) const
00503 {
00504     if(_root==NULL) return std::vector<Point>();
00505     std::vector<CoverTreeNode*> v = kNearestNodes(p, k);
00506     std::vector<Point> kNN;
00507     typename std::vector<CoverTreeNode*>::const_iterator it;
00508     for(it=v.begin();it!=v.end();++it) {
00509         const std::vector<Point>& po = (*it)->getPoints();
00510         kNN.insert(kNN.end(),po.begin(),po.end());
00511         if(kNN.size() >= k) break;
00512     }
00513     return kNN;
00514 }
00515 
00516 template<class Point>
00517 typename CoverTree<Point>::CoverTreeNode* CoverTree<Point>::getRoot() const
00518 {
00519     return _root;
00520 }
00521 
00522 template<class Point>
00523 CoverTree<Point>::CoverTreeNode::CoverTreeNode(const Point& p) {
00524     _points.push_back(p);
00525 }
00526 
00527 template<class Point>
00528 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00529 CoverTree<Point>::CoverTreeNode::getChildren(int level) const
00530 {
00531     typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator
00532         it = _childMap.find(level);
00533     if(it!=_childMap.end()) {
00534         return it->second;
00535     }
00536     return std::vector<CoverTreeNode*>();
00537 }
00538 
00539 template<class Point>
00540 void CoverTree<Point>::CoverTreeNode::addChild(int level, CoverTreeNode* p)
00541 {
00542     _childMap[level].push_back(p);
00543 }
00544 
00545 template<class Point>
00546 void CoverTree<Point>::CoverTreeNode::removeChild(int level, CoverTreeNode* p)
00547 {
00548     std::vector<CoverTreeNode*>& v = _childMap[level];
00549     for(unsigned int i=0;i<v.size();i++) {
00550         if(v[i]==p) {
00551             v[i]=v.back();
00552             v.pop_back();
00553             break;
00554         }
00555     }
00556 }
00557 
00558 template<class Point>
00559 void CoverTree<Point>::CoverTreeNode::addPoint(const Point& p)
00560 {
00561     if(find(_points.begin(), _points.end(), p) == _points.end())
00562         _points.push_back(p);
00563 }
00564 
00565 template<class Point>
00566 void CoverTree<Point>::CoverTreeNode::removePoint(const Point& p)
00567 {
00568     typename std::vector<Point>::iterator it =
00569         find(_points.begin(), _points.end(), p);
00570     if(it != _points.end())
00571         _points.erase(it);
00572 }
00573 
00574 template<class Point>
00575 double CoverTree<Point>::CoverTreeNode::distance(const CoverTreeNode& p) const
00576 {
00577     return _points[0].distance(p.getPoint());
00578 }
00579 
00580 template<class Point>
00581 bool CoverTree<Point>::CoverTreeNode::isSingle() const
00582 {
00583     return _points.size() == 1;
00584 }
00585 
00586 template<class Point>
00587 bool CoverTree<Point>::CoverTreeNode::hasPoint(const Point& p) const
00588 {
00589     return find(_points.begin(), _points.end(), p) != _points.end();
00590 }
00591 
00592 template<class Point>
00593 const Point& CoverTree<Point>::CoverTreeNode::getPoint() const { return _points[0]; }
00594 
00595 template<class Point>
00596 std::vector<typename CoverTree<Point>::CoverTreeNode*>
00597 CoverTree<Point>::CoverTreeNode::getAllChildren() const
00598 {
00599     std::vector<CoverTreeNode*> children;
00600     typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator it;
00601     for(it=_childMap.begin();it!=_childMap.end();++it) {
00602         children.insert(children.end(), it->second.begin(), it->second.end());
00603     }
00604     return children;
00605 }
00606 
00607 template<class Point>
00608 bool CoverTree<Point>::isValidTree() const {
00609     if(_numNodes==0)
00610         return _root==NULL;
00611 
00612     std::vector<CoverTreeNode*> nodes;
00613     nodes.push_back(_root);
00614     for(int i=_maxLevel;i>_minLevel;i--) {
00615         double sep = pow(base,i);
00616         typename std::vector<CoverTreeNode*>::const_iterator it, it2;
00617         //verify separation invariant of cover tree: for each level,
00618         //every point is farther than base^level away
00619         for(it=nodes.begin(); it!=nodes.end(); ++it) {
00620             for(it2=nodes.begin(); it2!=nodes.end(); ++it2) {
00621                 double dist=(*it)->distance((*it2)->getPoint());
00622                 if(dist<=sep && dist!=0.0) {
00623                     std::cout << "Level " << i << " Separation invariant failed.\n";
00624                     return false;
00625                 }
00626             }
00627         }
00628         std::vector<CoverTreeNode*> allChildren;
00629         for(it=nodes.begin(); it!=nodes.end(); ++it) {
00630             std::vector<CoverTreeNode*> children = (*it)->getChildren(i);
00631             //verify covering tree invariant: the children of node n at level
00632             //i are no further than base^i away
00633             for(it2=children.begin(); it2!=children.end(); ++it2) {
00634                 double dist = (*it2)->distance((*it)->getPoint());
00635                 if(dist>sep) {
00636                     std::cout << "Level" << i << " covering tree invariant failed.n";
00637                     return false;
00638                 }
00639             }
00640             allChildren.insert
00641                 (allChildren.end(),children.begin(),children.end());
00642         }
00643         nodes.insert(nodes.begin(),allChildren.begin(),allChildren.end());
00644     }
00645     return true;
00646 }
00647 }
00648 #endif // _COVER_TREE_H
00649 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation