// SHOGUN v2.0.0
/*
 * Copyright (C) 2011 by Singularity Institute for Artificial Intelligence
 * All Rights Reserved
 *
 * Written by David Crane <dncrane@gmail.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License v3 as
 * published by the Free Software Foundation and including the exceptions
 * at http://opencog.org/wiki/Licenses
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program; if not, write to:
 * Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Shogun modifications by Sergey Lisitsyn
 */

#ifndef _COVER_TREE_H
#define _COVER_TREE_H

#include <vector>
#include <algorithm>
#include <map>
#include <set>
#include <cmath>
#include <cstddef>
#include <utility>
#include <float.h>
#include <iostream>

namespace shogun
{

/** @brief Cover tree over points of type Point, supporting insertion,
 * removal and k-nearest-neighbour queries.
 *
 * Requirements on Point, as used by this header:
 *  - double Point::distance(const Point&) const
 *  - bool operator==(const Point&, const Point&) (used through std::find)
 *  - copy constructible (points are stored by value in std::vector)
 *
 * The tree owns its nodes through raw pointers and releases them in the
 * destructor. No copy constructor or assignment operator is declared, so
 * a copied tree shares _root with the original and both destructors
 * delete the same nodes; do not copy CoverTree objects.
 */
template<class Point>
class CoverTree
{
    /** @brief A node of the cover tree.
     *
     * A node stores one or more points that are at distance 0 from each
     * other, plus its children grouped by the level at which they were
     * attached. Nodes do not delete their children; CoverTree does.
     */
    class CoverTreeNode
    {
    private:
        //_childMap[i] is a vector of the node's children at level i
        std::map<int,std::vector<CoverTreeNode*> > _childMap;
        //_points is all of the points with distance 0 which are not equal.
        std::vector<Point> _points;
    public:
        /** Creates a node holding the single point p. */
        CoverTreeNode(const Point& p);
        /** @return copy of the children attached at the given level
         * (empty if there are none); the node itself is not included. */
        std::vector<CoverTreeNode*> getChildren(int level) const;
        /** Appends p to the children at the given level. */
        void addChild(int level, CoverTreeNode* p);
        /** Detaches p from the children at the given level (order of the
         * remaining children is not preserved); p is not deleted. */
        void removeChild(int level, CoverTreeNode* p);
        /** Stores p in this node unless an equal point is already stored. */
        void addPoint(const Point& p);
        /** Removes the stored point equal to p, if any. */
        void removePoint(const Point& p);
        /** @return all points stored in this node. */
        const std::vector<Point>& getPoints() const { return _points; }
        /** @return distance between the first points of both nodes. */
        double distance(const CoverTreeNode& p) const;
        /** @return distance between this node's first point and p.
         * Avoids building a temporary node through the implicit
         * CoverTreeNode(const Point&) conversion. */
        double distance(const Point& p) const;

        /** @return true if exactly one point is stored. */
        bool isSingle() const;
        /** @return true if a point equal to p is stored. */
        bool hasPoint(const Point& p) const;

        /** @return the first stored point; the node must not be empty. */
        const Point& getPoint() const;

        /** @return the children of all levels, in ascending level order. */
        std::vector<CoverTreeNode*> getAllChildren() const;
    }; // CoverTreeNode class
private:
    typedef std::pair<double, CoverTreeNode*> distNodePair;

    CoverTreeNode* _root;
    unsigned int _numNodes;
    int _maxLevel;//base^_maxLevel should be the max distance
                  //between any 2 points
    int _minLevel;//A level beneath which there are no more new nodes.

    /** @return up to k nodes nearest to p, ordered by increasing
     * distance; empty if the tree is empty or k is 0. */
    std::vector<CoverTreeNode*>
    kNearestNodes(const Point& p, const unsigned int& k) const;

    /** Recursive part of insert().
     *
     * @param p point to insert; no node at distance 0 may exist
     * @param Qi (distance to p, node) pairs forming the cover set at level
     * @param level current level
     * @return true if no parent has been found yet for p at or below this
     * level, false once p has been attached
     */
    bool insert_rec(const Point& p,
                    const std::vector<distNodePair>& Qi,
                    const int& level);

    /** @return the (distance, node) pair of the node in Q nearest to p;
     * (DBL_MAX, NULL) if Q is empty. */
    distNodePair distance(const Point& p,
                          const std::vector<CoverTreeNode*>& Q);

    /** Recursive part of remove().
     *
     * @param p point to remove
     * @param coverSets per-level cover sets of (distance to p, node);
     * coverSets[level] must be filled in by the caller
     * @param level current level
     * @param multi set to true once p has been dropped from a node that
     * holds several points, so that callers higher up do nothing more
     */
    void remove_rec(const Point& p,
                    std::map<int,std::vector<distNodePair> >& coverSets,
                    int level,
                    bool& multi);

public:
    /** Base of the level scale: level i is associated with distance base^i. */
    static const double base;

    /** Builds a tree and inserts the given points.
     *
     * @param maxDist upper bound on the distance between any two points;
     * the top level is ceil(log(maxDist)/log(base)). Non-positive (or NaN)
     * values fall back to top level 0.
     * @param points points to insert initially
     */
    CoverTree(const double& maxDist,
              const std::vector<Point>& points=std::vector<Point>());
    ~CoverTree();

    /** Checks the separation and covering invariants for every level above
     * _minLevel and prints a message to std::cout on the first violation.
     *
     * @return true if no violation was found
     */
    bool isValidTree() const;

    /** Inserts newPoint. A point at distance 0 from an existing node is
     * stored in that node (unless an equal point is already there).
     *
     * NOTE(review): the result of insert_rec is ignored, so a point for
     * which no parent is found (e.g. one farther than base^_maxLevel from
     * every candidate) is silently not stored.
     */
    void insert(const Point& newPoint);

    /** Removes p from the tree; does nothing if the tree is empty. */
    void remove(const Point& p);

    /** @return the points of the nodes nearest to p, nearest node first.
     * Nodes are appended whole until at least k points are collected, so
     * more than k points can be returned when a node holds several points.
     * Empty if the tree is empty or k is 0.
     */
    std::vector<Point> kNearestNeighbors(const Point& p, const unsigned int& k) const;

    /** @return the root node, or NULL for an empty tree. */
    CoverTreeNode* getRoot() const;

}; // CoverTree class

template<class Point>
const double CoverTree<Point>::base = 2.0;

template<class Point>
CoverTree<Point>::CoverTree(const double& maxDist,
                            const std::vector<Point>& points)
    : _root(NULL), _numNodes(0), _maxLevel(0), _minLevel(-1)
{
    // Computed in double precision: rounding the ratio to float first
    // could yield a top level that is one too small. The guard avoids
    // converting -inf/NaN to int.
    if(maxDist > 0.0) {
        _maxLevel = static_cast<int>(
            std::ceil(std::log(maxDist)/std::log(base)));
    }
    _minLevel = _maxLevel-1;
    typename std::vector<Point>::const_iterator it;
    for(it=points.begin(); it!=points.end(); ++it) {
        this->insert(*it);
    }
}

template<class Point>
CoverTree<Point>::~CoverTree()
{
    if(_root==NULL) return;
    // Depth-first walk with an explicit stack: take a node, queue all of
    // its children (from any level), delete it.
    std::vector<CoverTreeNode*> pending(1,_root);
    while(!pending.empty()) {
        CoverTreeNode* byeNode = pending.back();
        pending.pop_back();
        const std::vector<CoverTreeNode*> children = byeNode->getAllChildren();
        pending.insert(pending.end(),children.begin(),children.end());
        delete byeNode;
    }
}

template<class Point>
std::vector<typename CoverTree<Point>::CoverTreeNode*>
CoverTree<Point>::kNearestNodes(const Point& p, const unsigned int& k) const
{
    if(_root==NULL || k==0) return std::vector<CoverTreeNode*>();
    //maxDist is the distance of the kth nearest known point to p, i.e. of
    //the farthest point from p in the set minNodes defined below.
    double maxDist = p.distance(_root->getPoint());
    //minNodes stores the k nearest known points to p.
    std::set<distNodePair> minNodes;

    minNodes.insert(std::make_pair(maxDist,_root));
    std::vector<distNodePair> Qj(1,std::make_pair(maxDist,_root));
    for(int level=_maxLevel; level>=_minLevel; level--) {
        // Only the nodes present before this level are expanded; their
        // children are appended to Qj while we go.
        const std::size_t numParents = Qj.size();
        for(std::size_t i=0; i<numParents; i++) {
            const std::vector<CoverTreeNode*> children =
                Qj[i].second->getChildren(level);
            typename std::vector<CoverTreeNode*>::const_iterator child;
            for(child=children.begin(); child!=children.end(); ++child) {
                const double d = p.distance((*child)->getPoint());
                if(d < maxDist || minNodes.size() < k) {
                    minNodes.insert(std::make_pair(d,*child));
                    //--minNodes.end() gives us an iterator to the greatest
                    //element of minNodes.
                    if(minNodes.size() > k) minNodes.erase(--minNodes.end());
                    maxDist = (--minNodes.end())->first;
                }
                Qj.push_back(std::make_pair(d,*child));
            }
        }
        // Drop every candidate farther than maxDist + base^level from p.
        const double sep = maxDist + std::pow(base, level);
        std::size_t i = 0;
        while(i < Qj.size()) {
            if(Qj[i].first > sep) {
                //quickly removes an element from a vector w/o preserving
                //order; the swapped-in element is examined next.
                Qj[i]=Qj.back();
                Qj.pop_back();
            } else {
                ++i;
            }
        }
    }
    std::vector<CoverTreeNode*> kNN;
    typename std::set<distNodePair>::const_iterator it;
    for(it=minNodes.begin(); it!=minNodes.end(); ++it) {
        kNN.push_back(it->second);
    }
    return kNN;
}

template<class Point>
bool CoverTree<Point>::insert_rec(const Point& p,
                                  const std::vector<distNodePair>& Qi,
                                  const int& level)
{
    std::vector<distNodePair> Qj;
    const double sep = std::pow(base,level);
    // minDist: nearest node among Qi and its level children.
    // minQiDist: nearest node among Qi only (the candidate parent).
    double minDist = DBL_MAX;
    distNodePair minQiDist(DBL_MAX,static_cast<CoverTreeNode*>(NULL));
    typename std::vector<distNodePair>::const_iterator it;
    for(it=Qi.begin(); it!=Qi.end(); ++it) {
        if(it->first<minQiDist.first) minQiDist = *it;
        if(it->first<minDist) minDist = it->first;
        if(it->first<=sep) Qj.push_back(*it);
        const std::vector<CoverTreeNode*> children =
            it->second->getChildren(level);
        typename std::vector<CoverTreeNode*>::const_iterator it2;
        for(it2=children.begin(); it2!=children.end(); ++it2) {
            const double d = p.distance((*it2)->getPoint());
            if(d<minDist) minDist = d;
            if(d<=sep) Qj.push_back(std::make_pair(d,*it2));
        }
    }
    // Nothing within base^level of p: the parent must be found higher up.
    if(minDist > sep) return true;

    const bool found = insert_rec(p,Qj,level-1);
    if(found && minQiDist.first <= sep) {
        if(level-1<_minLevel) _minLevel = level-1;
        minQiDist.second->addChild(level, new CoverTreeNode(p));
        _numNodes++;
        return false;
    }
    return found;
}

template<class Point>
void CoverTree<Point>::remove_rec(const Point& p,
                                  std::map<int,std::vector<distNodePair> >& coverSets,
                                  int level,
                                  bool& multi)
{
    // References into a std::map stay valid while other keys are added.
    std::vector<distNodePair>& Qi = coverSets[level];
    std::vector<distNodePair>& Qj = coverSets[level-1];
    double minDist = DBL_MAX;
    CoverTreeNode* minNode = _root;
    CoverTreeNode* parent = NULL;
    const double sep = std::pow(base, level);
    typename std::vector<distNodePair>::const_iterator it_;
    //set Qj to be all children q of Qi such that p.distance(q)<=sep
    //and also keep track of the minimum distance from p to a node in Qj
    //note that every node has itself as a child, but the
    //getChildren function only returns non-self-children.
    for(it_=Qi.begin(); it_!=Qi.end(); ++it_) {
        const std::vector<CoverTreeNode*> children =
            it_->second->getChildren(level);
        double dist = it_->first;
        if(dist<minDist) {
            minDist = dist;
            minNode = it_->second;
        }
        if(dist <= sep) {
            Qj.push_back(*it_);
        }
        typename std::vector<CoverTreeNode*>::const_iterator it2;
        for(it2=children.begin(); it2!=children.end(); ++it2) {
            dist = p.distance((*it2)->getPoint());
            if(dist<minDist) {
                minDist = dist;
                minNode = *it2;
                // p's node hangs below it_->second at this level.
                if(dist == 0.0) parent = it_->second;
            }
            if(dist <= sep) {
                Qj.push_back(std::make_pair(dist,*it2));
            }
        }
    }
    if(level>_minLevel) remove_rec(p,coverSets,level-1,multi);
    if(!minNode->hasPoint(p)) return;

    //the multi flag indicates the point we removed is from a
    //node containing multiple points, and we have removed it,
    //so we don't need to do anything else.
    if(multi) return;
    if(!minNode->isSingle()) {
        minNode->removePoint(p);
        multi = true;
        return;
    }
    if(parent!=NULL) parent->removeChild(level, minNode);
    const std::vector<CoverTreeNode*> children = minNode->getChildren(level-1);
    // Take minNode out of the cover set one level down (order not kept).
    for(std::size_t i=0; i<Qj.size(); i++) {
        if(Qj[i].second==minNode) {
            Qj[i]=Qj.back();
            Qj.pop_back();
            break;
        }
    }
    // Re-attach each orphaned child to the first cover-set node within
    // base^i of it, walking the levels upwards from level-1. While no such
    // node exists at level i, the child joins the cover set of that level.
    // NOTE(review): this loop has no upper bound on i; it relies on some
    // cover set containing a node close enough to the child - confirm.
    typename std::vector<CoverTreeNode*>::const_iterator it;
    for(it=children.begin(); it!=children.end(); ++it) {
        int i = level-1;
        const Point q = (*it)->getPoint();
        double sep_ = std::pow(base,i);
        CoverTreeNode* newParent = NULL;
        while(newParent==NULL) {
            std::vector<distNodePair>& Q_ = coverSets[i];
            typename std::vector<distNodePair>::const_iterator it2;
            for(it2=Q_.begin(); it2!=Q_.end(); ++it2) {
                if(q.distance(it2->second->getPoint()) <= sep_) {
                    newParent = it2->second;
                    break;
                }
            }
            if(newParent!=NULL) break;
            Q_.push_back(std::make_pair((*it)->distance(p),*it));
            i++;
            sep_ = std::pow(base,i);
        }
        newParent->addChild(i,*it);
    }
    // A node without a parent at this level is either the root (deleted by
    // remove()) or is deleted at the level where its parent is known.
    if(parent!=NULL) {
        delete minNode;
        _numNodes--;
    }
}

template<class Point>
std::pair<double, typename CoverTree<Point>::CoverTreeNode*>
CoverTree<Point>::distance(const Point& p,
                           const std::vector<CoverTreeNode*>& Q)
{
    double minDist = DBL_MAX;
    CoverTreeNode* minNode = NULL; // stays NULL when Q is empty
    typename std::vector<CoverTreeNode*>::const_iterator it;
    for(it=Q.begin(); it!=Q.end(); ++it) {
        const double dist = p.distance((*it)->getPoint());
        if(dist < minDist) {
            minDist = dist;
            minNode = *it;
        }
    }
    return std::make_pair(minDist,minNode);
}

template<class Point>
void CoverTree<Point>::insert(const Point& newPoint)
{
    if(_root==NULL) {
        _root = new CoverTreeNode(newPoint);
        _numNodes=1;
        return;
    }
    //TODO: this is pretty inefficient, there may be a better way
    //to check if the node already exists...
    CoverTreeNode* n = kNearestNodes(newPoint,1)[0];
    if(newPoint.distance(n->getPoint())==0.0) {
        n->addPoint(newPoint);
    } else {
        //insert_rec acts under the assumption that there are no nodes with
        //distance 0 to newPoint in the cover tree (the previous lines check it)
        insert_rec(newPoint,
                   std::vector<distNodePair>
                   (1,std::make_pair(_root->distance(newPoint),_root)),
                   _maxLevel);
    }
}

template<class Point>
void CoverTree<Point>::remove(const Point& p)
{
    //Most of this function's code is for the special case of removing the root
    if(_root==NULL) return;
    const bool removingRoot = _root->hasPoint(p);
    if(removingRoot && !_root->isSingle()) {
        _root->removePoint(p);
        return;
    }
    CoverTreeNode* newRoot = NULL;
    if(removingRoot) {
        if(_numNodes==1) {
            //removing the last node...
            delete _root;
            _numNodes--;
            _root=NULL;
            return;
        }
        // Promote the last child of the highest non-empty level.
        for(int i=_maxLevel; i>_minLevel; i--) {
            const std::vector<CoverTreeNode*> kids = _root->getChildren(i);
            if(!kids.empty()) {
                newRoot = kids.back();
                _root->removeChild(i,newRoot);
                break;
            }
        }
        if(newRoot==NULL) {
            // The node count claims more nodes but the root has no child
            // to promote; treat the root as the last node instead of
            // dereferencing a null pointer below.
            delete _root;
            _root=NULL;
            _numNodes=0;
            return;
        }
    }
    // NOTE(review): when the root is removed, the old root stays in
    // coverSets[_maxLevel] while remove_rec re-parents its children, and
    // only children below level _maxLevel are visited there - verify that
    // no child ends up attached to (or left under) the deleted root.
    std::map<int, std::vector<distNodePair> > coverSets;
    coverSets[_maxLevel].push_back(std::make_pair(_root->distance(p),_root));
    if(removingRoot)
        coverSets[_maxLevel].push_back(std::make_pair(newRoot->distance(p),newRoot));
    bool multi = false;
    remove_rec(p,coverSets,_maxLevel,multi);
    if(removingRoot) {
        delete _root;
        _numNodes--;
        _root=newRoot;
    }
}

template<class Point>
std::vector<Point> CoverTree<Point>::kNearestNeighbors(const Point& p,
                                                       const unsigned int& k) const
{
    std::vector<Point> kNN;
    if(_root==NULL || k==0) return kNN;
    const std::vector<CoverTreeNode*> v = kNearestNodes(p, k);
    typename std::vector<CoverTreeNode*>::const_iterator it;
    for(it=v.begin(); it!=v.end(); ++it) {
        const std::vector<Point>& po = (*it)->getPoints();
        kNN.insert(kNN.end(),po.begin(),po.end());
        if(kNN.size() >= k) break;
    }
    return kNN;
}

template<class Point>
typename CoverTree<Point>::CoverTreeNode* CoverTree<Point>::getRoot() const
{
    return _root;
}

template<class Point>
CoverTree<Point>::CoverTreeNode::CoverTreeNode(const Point& p) {
    _points.push_back(p);
}

template<class Point>
std::vector<typename CoverTree<Point>::CoverTreeNode*>
CoverTree<Point>::CoverTreeNode::getChildren(int level) const
{
    typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator
        it = _childMap.find(level);
    if(it!=_childMap.end()) {
        return it->second;
    }
    return std::vector<CoverTreeNode*>();
}

template<class Point>
void CoverTree<Point>::CoverTreeNode::addChild(int level, CoverTreeNode* p)
{
    _childMap[level].push_back(p);
}

template<class Point>
void CoverTree<Point>::CoverTreeNode::removeChild(int level, CoverTreeNode* p)
{
    // find() instead of operator[] so that no empty level is created.
    typename std::map<int,std::vector<CoverTreeNode*> >::iterator
        entry = _childMap.find(level);
    if(entry==_childMap.end()) return;
    std::vector<CoverTreeNode*>& v = entry->second;
    for(std::size_t i=0; i<v.size(); i++) {
        if(v[i]==p) {
            v[i]=v.back();
            v.pop_back();
            break;
        }
    }
}

template<class Point>
void CoverTree<Point>::CoverTreeNode::addPoint(const Point& p)
{
    if(std::find(_points.begin(), _points.end(), p) == _points.end())
        _points.push_back(p);
}

template<class Point>
void CoverTree<Point>::CoverTreeNode::removePoint(const Point& p)
{
    typename std::vector<Point>::iterator it =
        std::find(_points.begin(), _points.end(), p);
    if(it != _points.end())
        _points.erase(it);
}

template<class Point>
double CoverTree<Point>::CoverTreeNode::distance(const CoverTreeNode& p) const
{
    return _points[0].distance(p.getPoint());
}

template<class Point>
double CoverTree<Point>::CoverTreeNode::distance(const Point& p) const
{
    return _points[0].distance(p);
}

template<class Point>
bool CoverTree<Point>::CoverTreeNode::isSingle() const
{
    return _points.size() == 1;
}

template<class Point>
bool CoverTree<Point>::CoverTreeNode::hasPoint(const Point& p) const
{
    return std::find(_points.begin(), _points.end(), p) != _points.end();
}

template<class Point>
const Point& CoverTree<Point>::CoverTreeNode::getPoint() const { return _points[0]; }

template<class Point>
std::vector<typename CoverTree<Point>::CoverTreeNode*>
CoverTree<Point>::CoverTreeNode::getAllChildren() const
{
    std::vector<CoverTreeNode*> children;
    typename std::map<int,std::vector<CoverTreeNode*> >::const_iterator it;
    for(it=_childMap.begin(); it!=_childMap.end(); ++it) {
        children.insert(children.end(), it->second.begin(), it->second.end());
    }
    return children;
}

template<class Point>
bool CoverTree<Point>::isValidTree() const {
    if(_numNodes==0)
        return _root==NULL;

    // nodes holds every node that exists at the level being checked.
    std::vector<CoverTreeNode*> nodes;
    nodes.push_back(_root);
    for(int i=_maxLevel; i>_minLevel; i--) {
        const double sep = std::pow(base,i);
        typename std::vector<CoverTreeNode*>::const_iterator it, it2;
        //verify separation invariant of cover tree: for each level,
        //every point is farther than base^level away
        for(it=nodes.begin(); it!=nodes.end(); ++it) {
            for(it2=nodes.begin(); it2!=nodes.end(); ++it2) {
                const double dist = (*it)->distance(**it2);
                if(dist<=sep && dist!=0.0) {
                    std::cout << "Level " << i << " Separation invariant failed.\n";
                    return false;
                }
            }
        }
        std::vector<CoverTreeNode*> allChildren;
        for(it=nodes.begin(); it!=nodes.end(); ++it) {
            const std::vector<CoverTreeNode*> children = (*it)->getChildren(i);
            //verify covering tree invariant: the children of node n at level
            //i are no further than base^i away
            for(it2=children.begin(); it2!=children.end(); ++it2) {
                const double dist = (*it2)->distance(**it);
                if(dist>sep) {
                    std::cout << "Level " << i << " covering tree invariant failed.\n";
                    return false;
                }
            }
            allChildren.insert
                (allChildren.end(),children.begin(),children.end());
        }
        nodes.insert(nodes.begin(),allChildren.begin(),allChildren.end());
    }
    return true;
}
}
#endif // _COVER_TREE_H