Savarese Software Research
Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members

KDTree.h

Go to the documentation of this file.
00001 /*
00002  * $Id: KDTree.h 5871 2005-10-28 02:10:33Z dfs $
00003  *
00004  * Copyright 2003-2005 Daniel F. Savarese
00005  * Copyright 2005 Savarese Software Research
00006  *
00007  * Licensed under the Apache License, Version 2.0 (the "License");
00008  * you may not use this file except in compliance with the License.
00009  * You may obtain a copy of the License at
00010  *
00011  *     https://www.savarese.com/software/ApacheLicense-2.0
00012  *
00013  * Unless required by applicable law or agreed to in writing, software
00014  * distributed under the License is distributed on an "AS IS" BASIS,
00015  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00016  * See the License for the specific language governing permissions and
00017  * limitations under the License.
00018  */
00019 
00025 #ifndef __SAVA_SPATIAL_KDTREE_H
00026 #define __SAVA_SPATIAL_KDTREE_H
00027 
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include <libsava/spatial/detail/KDTree.h>
#include <libsava/spatial/Point.h>
00034 
00035 __BEGIN_PACKAGE_SAVA_SPATIAL
00036 
/**
 * Traits describing a KDTree-like type `Tree`.
 *
 * Re-exports the tree's member types and its dimension count so that the
 * detail classes (nodes, range-search iterators) can be parameterized on a
 * single traits type.
 */
template<typename Tree>
struct KDTreeTraits {
  typedef typename Tree::key_type key_type;
  typedef typename Tree::mapped_type mapped_type;
  typedef typename Tree::value_type value_type;
  typedef typename Tree::pointer pointer;
  typedef typename Tree::const_pointer const_pointer;
  typedef typename Tree::reference reference;
  typedef typename Tree::const_reference const_reference;
  typedef typename Tree::discriminator_type discriminator_type;
  typedef typename Tree::node_type node_type;
  typedef typename node_type::child_type child_type;
  typedef typename Tree::iterator iterator;
  typedef typename Tree::const_iterator const_iterator;
  typedef typename Tree::point_traits point_traits;
  typedef Tree tree_type;

  /**
   * Returns Tree::dimensions().
   *
   * (A top-level `const` on a scalar return type is ignored by the language
   * and only triggers -Wignored-qualifiers, so it is omitted.)
   */
  static inline unsigned int dimensions() {
    return Tree::dimensions();
  }
};
00061 
00062 
00068 template<typename Tree>
00069 struct KDTreeConstTraits : public KDTreeTraits<Tree> {
00070   typedef typename Tree::const_pointer pointer;
00071   typedef typename Tree::const_reference reference;
00072 };
00073 
00074 
00075 // Note: we store the discriminator in each node to avoid modulo division,
00076 // trading space for time.
00089 template <typename Point,
00090           typename Value,
00091           unsigned int Dimensions = 2,
00092           typename Discriminator = unsigned char,
00093           typename Size = unsigned int>
00094 class KDTree {
00095 
00096 public:
00097 
00098   typedef KDTreeTraits<KDTree> traits;
00099   typedef KDTreeConstTraits<KDTree> const_traits;
00100   typedef Point key_type;
00101   typedef Value mapped_type;
00102   typedef std::pair<const key_type, mapped_type> value_type;
00103   typedef value_type* pointer;
00104   typedef pointer const const_pointer;
00105   typedef value_type& reference;
00106   typedef const reference const_reference;
00107   typedef Discriminator discriminator_type;
00108   typedef detail::KDTreeRangeSearchIterator<traits> iterator;
00109   typedef detail::KDTreeRangeSearchIterator<const_traits> const_iterator;
00110   typedef PointTraits<key_type> point_traits;
00111 
00112   typedef Size size_type;
00113 
00114   static inline const unsigned int dimensions() {
00115     return Dimensions;
00116   }
00117 
00118   typedef detail::KDTreeNode<traits> node_type;
00119   typedef typename node_type::child_type child_type;
00120 
00121 private:
00122 
00123   struct NodeComparator {
00124     discriminator_type discriminator;
00125 
00126     explicit NodeComparator() : discriminator(0) { }
00127 
00128     bool operator()(const node_type* const & n1,
00129                     const node_type* const & n2) const
00130     {
00131       return (n1->point()[discriminator] < n2->point()[discriminator]);
00132     }
00133   };
00134 
00135   node_type *_root;
00136   size_type _size;
00137   iterator _endIterator;
00138   const_iterator _constEndIterator;
00139 
00140   node_type * getNode(const key_type & point, node_type **parent = 0) const {
00141     discriminator_type discriminator;
00142     child_type child;
00143     node_type *node = _root, *last = 0;
00144 
00145     while(node != 0) {
00146       discriminator = node->discriminator();
00147 
00148       if(point[discriminator] > node->point()[discriminator])
00149         child = node_type::ChildHigh;
00150       else if(point[discriminator] < node->point()[discriminator])
00151         child = node_type::ChildLow;
00152       else if(node->point() == point) {
00153         if(parent != 0)
00154           *parent = last;
00155         return node;
00156       } else
00157         child = node_type::ChildHigh;
00158 
00159       last = node;
00160       node = node->child(child);
00161     }
00162 
00163     if(parent != 0)
00164       *parent = last;
00165 
00166     return 0;
00167   }
00168 
00169   node_type * getMinimumNode(node_type *node, node_type *p,
00170                              const discriminator_type discriminator,
00171                              node_type **parent)
00172   {
00173     node_type *result;
00174 
00175     if(discriminator == node->discriminator()) {
00176       if(node->child(node_type::ChildLow) != 0)
00177         return
00178           getMinimumNode(node->child(node_type::ChildLow), node,
00179                               discriminator, parent);
00180       else
00181         result = node;
00182     } else {
00183       node_type *nlow = 0, *nhigh = 0;
00184       node_type *plow, *phigh;
00185 
00186       if(node->child(node_type::ChildLow) != 0)
00187         nlow =
00188           getMinimumNode(node->child(node_type::ChildLow), node,
00189                               discriminator, &plow);
00190 
00191       if(node->child(node_type::ChildHigh) != 0)
00192         nhigh =
00193           getMinimumNode(node->child(node_type::ChildHigh), node,
00194                               discriminator, &phigh);
00195 
00196       if(nlow != 0 && nhigh != 0) {
00197         if(nlow->point()[discriminator] < nhigh->point()[discriminator]) {
00198           result  = nlow;
00199           *parent = plow;
00200         } else {
00201           result  = nhigh;
00202           *parent = phigh;
00203         }
00204       } else if(nlow != 0) {
00205         result  = nlow;
00206         *parent = plow;
00207       } else if(nhigh != 0) {
00208         result  = nhigh;
00209         *parent = phigh;
00210       } else
00211         result  = node;
00212     }
00213 
00214     if(result == node)
00215       *parent = p;
00216     else if(node->point()[discriminator] < result->point()[discriminator]) {
00217       result  = node;
00218       *parent = p;
00219     }
00220 
00221     return result;
00222   }
00223 
00224 
00225   node_type * recursiveRemoveNode(node_type *node) {
00226     discriminator_type discriminator;
00227     node_type *newRoot, *parent;
00228 
00229     if(node->child(node_type::ChildLow) == 0 &&
00230        node->child(node_type::ChildHigh) == 0)
00231       return 0;
00232     else
00233       discriminator = node->discriminator();
00234 
00235     if(node->child(node_type::ChildHigh) == 0) {
00236       node->child(node_type::ChildHigh) = node->child(node_type::ChildLow);
00237       node->child(node_type::ChildLow)  = 0;
00238     }
00239 
00240     newRoot =
00241       getMinimumNode(node->child(node_type::ChildHigh), node,
00242                      discriminator, &parent);
00243 
00244     child_type child = (parent->child(node_type::ChildLow) == newRoot ?
00245                         node_type::ChildLow : node_type::ChildHigh);
00246     parent->child(child) = recursiveRemoveNode(newRoot);
00247 
00248     newRoot->child(node_type::ChildLow)  = node->child(node_type::ChildLow);
00249     newRoot->child(node_type::ChildHigh) = node->child(node_type::ChildHigh);
00250     newRoot->discriminator() = node->discriminator();
00251 
00252     return newRoot;
00253   }
00254 
00255 
00256   bool add(const key_type & point, const mapped_type & value,
00257             node_type **node, node_type *parent,
00258             mapped_type *replaced = 0)
00259   {
00260     if(parent == 0) {
00261       if(_root != 0)
00262         *node = _root;
00263       else {
00264         _root = *node = new node_type(0, point, value);
00265         ++_size;
00266         return false;
00267       }
00268     } else if(*node == 0) {
00269       discriminator_type discriminator;
00270       child_type child;
00271 
00272       discriminator = parent->discriminator();
00273       child = (point[discriminator] >= parent->point()[discriminator] ?
00274                node_type::ChildHigh : node_type::ChildLow);
00275 
00276       if(++discriminator >= dimensions())
00277         discriminator = 0;
00278 
00279       parent->child(child) = *node =
00280         new node_type(discriminator, point, value);
00281 
00282       ++_size;
00283       return false;
00284     }
00285 
00286     if(replaced != 0)
00287       *replaced = (*node)->value();
00288 
00289     (*node)->value() = value;
00290 
00291     return true;
00292   }
00293 
00294   template<template<typename> class Container>
00295   static inline
00296   node_type * optimize(typename Container<node_type*>::iterator begin,
00297                        typename Container<node_type*>::iterator end,
00298                        NodeComparator & comparator)
00299   {
00300     node_type *midpoint = 0;
00301     typename Container<node_type*>::iterator::difference_type diff;
00302 
00303     diff = end - begin;
00304 
00305     if(diff > 1) {
00306       discriminator_type discriminator = comparator.discriminator;
00307       typename Container<node_type*>::iterator nth = begin + (diff >> 1);
00308       typename Container<node_type*>::iterator nthprev = nth - 1;
00309 
00310       //nth_element(begin, nth, end, comparator);
00311       stable_sort(begin, end, comparator);
00312 
00313       // Ties go in the right subtree.
00314       while(nth > begin &&
00315             (*nth)->point()[discriminator] == 
00316             (*nthprev)->point()[discriminator])
00317         {
00318           --nth;
00319           --nthprev;
00320         }
00321 
00322       midpoint = *nth;
00323       midpoint->discriminator() = discriminator;
00324 
00325       if(++discriminator >= dimensions())
00326         discriminator = 0;
00327 
00328       comparator.discriminator = discriminator;
00329 
00330       // Left subtree
00331       midpoint->child(node_type::ChildLow) =
00332         optimize<Container>(begin, nth, comparator);
00333 
00334       comparator.discriminator = discriminator;
00335 
00336       // Right subtree
00337       midpoint->child(node_type::ChildHigh) =
00338         optimize<Container>(nth + 1, end, comparator);
00339     } else if(diff == 1) {
00340       midpoint = *begin;
00341       midpoint->discriminator() = comparator.discriminator;
00342       midpoint->child(node_type::ChildLow) = 0;
00343       midpoint->child(node_type::ChildHigh) = 0;
00344     }
00345 
00346     return midpoint;
00347   }
00348 
00349   template<template<typename> class Container>
00350   static inline void fillContainer(Container<node_type*> & c, node_type *node)
00351   {
00352     if(node == 0)
00353       return;
00354     c.push_back(node);
00355     fillContainer(c, node->child(node_type::ChildLow));
00356     fillContainer(c, node->child(node_type::ChildHigh));
00357   }
00358 
00359   static inline 
00360   void initPoint(key_type & point,
00361                  const typename point_traits::coordinate_type & value)
00362   {
00363     for(unsigned int i=0; i < point_traits::dimensions(); ++i)
00364       point[i] = value;
00365   }
00366   
00367   static inline const key_type upperBound() {
00368     key_type bound;
00369     initPoint(bound, point_traits::max_coordinate());
00370     return bound;
00371   }
00372 
00373   static inline const key_type lowerBound() {
00374     key_type bound;
00375     initPoint(bound, point_traits::min_coordinate());
00376     return bound;
00377   }
00378 
00379 public:
00380 
00381   // TODO: need assignment operator.
00382 
00386   explicit KDTree() : _root(0), _size(0), _endIterator(), _constEndIterator()
00387   { }
00388 
00394   KDTree(const KDTree & tree) :
00395     _root(0), _size(0), _endIterator(), _constEndIterator()
00396   {
00397     for(const_iterator p = tree.begin(); !p.endOfRange(); ++p)
00398       insert(p->first, p->second);
00399   }
00400 
00404   virtual ~KDTree() { delete _root; }
00405 
00409   void clear() {
00410     delete _root;
00411     _root = 0;
00412     _size = 0;
00413   }
00414 
00420   const size_type size() const {
00421     return _size;
00422   }
00423 
00429   const size_type max_size() const {
00430     return std::numeric_limits<size_type>::max();
00431   }
00432 
00440   bool empty() const {
00441     return (_root == 0);
00442   }
00443 
00451   iterator begin() {
00452     return iterator(lowerBound(), upperBound(), _root);
00453   }
00454 
00462   const_iterator begin() const {
00463     return const_iterator(lowerBound(), upperBound(), _root);
00464   }
00465 
00471   iterator & end() {
00472     return _endIterator;
00473   }
00474 
00480   const const_iterator & end() const {
00481     return _constEndIterator;
00482   }
00483 
00496   bool insert(const key_type & point, const mapped_type & value,
00497               mapped_type *replaced = 0)
00498   {
00499     node_type *parent;
00500     node_type *node = getNode(point, &parent);
00501 
00502     return add(point, value, &node, parent, replaced);
00503   }
00504 
00516   std::pair<iterator,bool> insert(const value_type & mapping) {
00517     // Ideally, we'd do this all in one step, but that will have
00518     // to wait until we optimize the way we handle iterators.
00519     bool replaced;
00520     mapped_type existing;
00521     iterator value;
00522 
00523     replaced = insert(mapping.first, mapping.second, &existing);
00524     value = find(mapping.first);
00525 
00526     if(replaced)
00527       value._node->value() = existing;
00528 
00529     return std::pair<iterator,bool>(value,!replaced);
00530   }
00531 
00532 
00542   mapped_type & operator[](const key_type & point) {
00543     node_type *parent;
00544     node_type *node = getNode(point, &parent);
00545 
00546     if(node == 0)
00547       add(point, mapped_type(), &node, parent);
00548 
00549     return node->value();
00550   }
00551 
00562   bool remove(const key_type & point, mapped_type *erased = 0) {
00563     node_type *parent;
00564     node_type *node = getNode(point, &parent);
00565     node_type *child;
00566 
00567     if(node == 0)
00568       return false;
00569 
00570     if(erased != 0)
00571       *erased = node->value();
00572 
00573     child = node;
00574     node  = recursiveRemoveNode(child);
00575 
00576     if(parent == 0)
00577       _root = node;
00578     else if(child == parent->child(node_type::ChildLow))
00579       parent->child(node_type::ChildLow) = node;
00580     else
00581       parent->child(node_type::ChildHigh) = node;
00582 
00583     // Must zero children so they are not deleted by ~node_type()
00584     child->child(node_type::ChildLow)  = 0; 
00585     child->child(node_type::ChildHigh) = 0;
00586 
00587     --_size;
00588     delete child;
00589 
00590     return true;
00591   }
00592 
00599   size_type erase(const key_type & point) {
00600     return remove(point);
00601   }
00602 
00608   void erase(iterator pos) {
00609     remove(pos->first);
00610   }
00611 
00625   iterator begin(const key_type & lower, const key_type & upper) {
00626     return iterator(lower, upper, _root);
00627   }
00628 
00642   const_iterator begin(const key_type & lower, const key_type & upper) const {
00643     return const_iterator(lower, upper, _root);
00644   }
00645 
00659   bool get(const key_type & point, mapped_type *value = 0) const {
00660     node_type *node = getNode(point);
00661 
00662     if(node == 0)
00663       return false;
00664     else if(value != 0)
00665       *value = node->value();
00666 
00667     return true;
00668   }
00669 
00680   iterator find(const key_type & point) {
00681     return iterator(point, upperBound(), _root, true);
00682   }
00683 
00694   const_iterator find(const key_type & point) const {
00695     return const_iterator(point, upperBound(), _root, true);
00696   }
00697 
00698   // Balances the tree.  Very expensive!
00706   void optimize() {
00707     if(empty())
00708       return;
00709 
00710     typedef std::vector<node_type*> container;
00711     container nodes;
00712 
00713     nodes.reserve(size());
00714     fillContainer<std::vector>(nodes, _root);
00715 
00716     NodeComparator comparator;
00717     _root =
00718       optimize<std::vector>(nodes.begin(), nodes.end(), comparator);
00719   }
00720 
00721 };
00722 
00723 
00724 __END_PACKAGE_SAVA_SPATIAL
00725 
00726 #endif
00727 

Savarese Software Research
Copyright © 2003-2005 Savarese Software Research and Daniel F. Savarese. All rights reserved.