Savarese Software Research Corporation
kd_tree.h
Go to the documentation of this file.
00001 /*
00002  * Copyright 2003-2005 Daniel F. Savarese
00003  * Copyright 2006-2009 Savarese Software Research Corporation
00004  *
00005  * Licensed under the Apache License, Version 2.0 (the "License");
00006  * you may not use this file except in compliance with the License.
00007  * You may obtain a copy of the License at
00008  *
00009  *     https://www.savarese.com/software/ApacheLicense-2.0
00010  *
00011  * Unless required by applicable law or agreed to in writing, software
00012  * distributed under the License is distributed on an "AS IS" BASIS,
00013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00014  * See the License for the specific language governing permissions and
00015  * limitations under the License.
00016  */
00017 
00023 #ifndef __SSRC_SPATIAL_KDTREE_H
00024 #define __SSRC_SPATIAL_KDTREE_H
00025 
00026 #include <ssrc/spatial/detail/kd_tree_range_search_iterator.h>
00027 #include <ssrc/spatial/detail/kd_tree_node.h>
00028 #include <ssrc/spatial/detail/kd_tree_nearest_neighbor.h>
00029 
00030 #ifdef LIBSSRCKDTREE_HAVE_BOOST
00031 #include <ssrc/spatial/detail/kd_tree_nearest_neighbors.h>
00032 #endif
00033 
00034 #include <ssrc/spatial/rectangle_region.h>
00035 
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>
00039 
00040 __BEGIN_NS_SSRC_SPATIAL
00041 
00045 template<typename Tree>
00046 struct kd_tree_traits {
00047   typedef typename Tree::key_type key_type;
00048   typedef typename Tree::mapped_type mapped_type;
00049   typedef typename Tree::value_type value_type;
00050   typedef typename Tree::pointer pointer;
00051   typedef typename Tree::const_pointer const_pointer;
00052   typedef typename Tree::reference reference;
00053   typedef typename Tree::const_reference const_reference;
00054   typedef typename Tree::discriminator_type discriminator_type;
00055   typedef typename Tree::node_type node_type;
00056   typedef typename Tree::iterator iterator;
00057   typedef typename Tree::const_iterator const_iterator;
00058   typedef typename Tree::size_type size_type;
00059   typedef typename key_type::value_type coordinate_type;
00060   typedef Tree tree_type;
00061 
00067   static const coordinate_type max_coordinate() {
00068     return detail::coordinate_limits<coordinate_type>::highest();
00069   }
00070 
00076   static const coordinate_type min_coordinate() {
00077     return detail::coordinate_limits<coordinate_type>::lowest();
00078   }
00079 
00081   static const key_type upper_bound;
00082 
00084   static const key_type lower_bound;
00085 
00087   static const unsigned int dimensions = NS_TR1::tuple_size<key_type>::value;
00088 
00089 private:
00090   static key_type init_point(const coordinate_type & value) {
00091     key_type point;
00092     for(unsigned int i=0; i < dimensions; ++i)
00093       point[i] = value;
00094     return point;
00095   }
00096 
00097   static const key_type _upper_bound() {
00098     return init_point(max_coordinate());
00099   }
00100 
00101   static const key_type _lower_bound() {
00102     return init_point(min_coordinate());
00103   }
00104 };
00105 
00106 template<class Tree>
00107 typename kd_tree_traits<Tree>::key_type const
00108 kd_tree_traits<Tree>::upper_bound(kd_tree_traits<Tree>::_upper_bound());
00109 
00110 template<class Tree>
00111 typename kd_tree_traits<Tree>::key_type const
00112 kd_tree_traits<Tree>::lower_bound(kd_tree_traits<Tree>::_lower_bound());
00113 
00117 template<typename Tree>
00118 struct kd_tree_const_traits : public kd_tree_traits<Tree> {
00119   typedef typename Tree::const_pointer pointer;
00120   typedef typename Tree::const_reference reference;
00121 };
00122 
00123 
00124 // Note: we store the discriminator in each node to avoid modulo division,
00125 // trading space for time.
00136 template<typename Point,
00137          typename Value,
00138          typename Discriminator = unsigned char,
00139          typename Size = unsigned int>
00140 class kd_tree {
00141 public:
00142 
00143   typedef kd_tree_traits<kd_tree> traits;
00144   typedef kd_tree_const_traits<kd_tree> const_traits;
00145   typedef Point key_type;
00146   typedef Value mapped_type;
00147   typedef std::pair<const key_type, mapped_type> value_type;
00148   typedef value_type* pointer;
00149   typedef pointer const const_pointer;
00150   typedef value_type& reference;
00151   typedef const reference const_reference;
00152   typedef Discriminator discriminator_type;
00153   typedef rectangle_region<key_type> default_region_type;
00154   // Is this really what we want--two distinct types as
00155   // opposed to iterator and const iterator?
00156   typedef
00157   detail::kd_tree_range_search_iterator<traits, default_region_type> iterator;
00158   typedef
00159   detail::kd_tree_range_search_iterator<const_traits, default_region_type>
00160   const_iterator;
00161 
00162   typedef Size size_type;
00163 
00164   typedef detail::kd_tree_node<traits> node_type;
00165 
00166 private:
00167 
00168   struct node_comparator {
00169     mutable discriminator_type discriminator;
00170 
00171     explicit node_comparator() : discriminator(0) { }
00172 
00173     bool operator()(const node_type* const & n1,
00174                     const node_type* const & n2) const
00175     {
00176       return (n1->point()[discriminator] < n2->point()[discriminator]);
00177     }
00178   };
00179 
00180   node_type *_root;
00181   size_type _size;
00182   iterator _end_iterator;
00183   const_iterator _const_end_iterator;
00184 
00185   node_type *
00186   get_node(const key_type & point, node_type ** const parent = 0) const {
00187     discriminator_type discriminator;
00188     node_type *node = _root, *last = 0;
00189 
00190     while(node != 0) {
00191       discriminator = node->discriminator;
00192 
00193       if(point[discriminator] > node->point()[discriminator]) {
00194         last = node;
00195         node = node->child_high;
00196       } else if(point[discriminator] < node->point()[discriminator]) {
00197         last = node;
00198         node = node->child_low;
00199       } else if(node->point() == point) {
00200         if(parent != 0)
00201           *parent = last;
00202         return node;
00203       } else {
00204         last = node;
00205         node = node->child_high;
00206       }
00207     }
00208 
00209     if(parent != 0)
00210       *parent = last;
00211 
00212     return 0;
00213   }
00214 
00215   node_type * get_minimum_node(node_type * const node, node_type * const p,
00216                                const discriminator_type discriminator,
00217                                node_type ** const parent)
00218   {
00219     node_type *result;
00220 
00221     if(discriminator == node->discriminator) {
00222       if(node->child_low != 0)
00223         return
00224           get_minimum_node(node->child_low, node,
00225                            discriminator, parent);
00226       else
00227         result = node;
00228     } else {
00229       node_type *nlow = 0, *nhigh = 0;
00230       node_type *plow, *phigh;
00231 
00232       if(node->child_low != 0)
00233         nlow = get_minimum_node(node->child_low, node,
00234                                 discriminator, &plow);
00235 
00236       if(node->child_high != 0)
00237         nhigh = get_minimum_node(node->child_high, node,
00238                                  discriminator, &phigh);
00239 
00240       if(nlow != 0 && nhigh != 0) {
00241         if(nlow->point()[discriminator] < nhigh->point()[discriminator]) {
00242           result  = nlow;
00243           *parent = plow;
00244         } else {
00245           result  = nhigh;
00246           *parent = phigh;
00247         }
00248       } else if(nlow != 0) {
00249         result  = nlow;
00250         *parent = plow;
00251       } else if(nhigh != 0) {
00252         result  = nhigh;
00253         *parent = phigh;
00254       } else
00255         result  = node;
00256     }
00257 
00258     if(result == node)
00259       *parent = p;
00260     else if(node->point()[discriminator] < result->point()[discriminator]) {
00261       result  = node;
00262       *parent = p;
00263     }
00264 
00265     return result;
00266   }
00267 
00268   node_type * recursive_remove_node(node_type * const node) {
00269     discriminator_type discriminator;
00270     node_type *new_root, *parent;
00271 
00272     if(node->child_low == 0 &&
00273        node->child_high == 0)
00274       return 0;
00275     else
00276       discriminator = node->discriminator;
00277 
00278     if(node->child_high == 0) {
00279       node->child_high = node->child_low;
00280       node->child_low  = 0;
00281     }
00282 
00283     new_root = get_minimum_node(node->child_high, node,
00284                                 discriminator, &parent);
00285 
00286     if(parent->child_low == new_root)
00287       parent->child_low = recursive_remove_node(new_root);
00288     else
00289       parent->child_high = recursive_remove_node(new_root);
00290 
00291     new_root->child_low  = node->child_low;
00292     new_root->child_high = node->child_high;
00293     new_root->discriminator = node->discriminator;
00294 
00295     return new_root;
00296   }
00297 
00298   // Splitting up remove in this way allows us to implement
00299   // iterator erase(iterator) properly.
00300   bool remove(node_type * const node, node_type * const parent) {
00301     node_type * const new_root = recursive_remove_node(node);
00302 
00303     if(parent == 0)
00304       _root = new_root;
00305     else if(node == parent->child_low)
00306       parent->child_low = new_root;
00307     else
00308       parent->child_high = new_root;
00309 
00310     // Must zero children so they are not deleted by ~node_type()
00311     node->child_low  = 0; 
00312     node->child_high = 0;
00313 
00314     --_size;
00315     delete node;
00316 
00317     return true;
00318   }
00319 
00320   bool add(const key_type & point, const mapped_type & value,
00321             node_type ** const node, node_type *parent,
00322             mapped_type * const replaced = 0)
00323   {
00324     if(parent == 0) {
00325       if(_root != 0)
00326         *node = _root;
00327       else {
00328         _root = *node = new node_type(0, point, value);
00329         ++_size;
00330         return false;
00331       }
00332     } else if(*node == 0) {
00333       discriminator_type discriminator = parent->discriminator;
00334       node_type* & child =
00335         (point[discriminator] >= parent->point()[discriminator] ?
00336          parent->child_high : parent->child_low);
00337 
00338       if(++discriminator >= traits::dimensions)
00339         discriminator = 0;
00340 
00341       child = *node = new node_type(discriminator, point, value);
00342 
00343       ++_size;
00344       return false;
00345     }
00346 
00347     if(replaced != 0)
00348       *replaced = (*node)->value();
00349 
00350     (*node)->value() = value;
00351 
00352     return true;
00353   }
00354 
00355   template<typename container_iterator>
00356   static node_type * optimize(const container_iterator & begin,
00357                               const container_iterator & end,
00358                               const node_comparator & comparator)
00359   {
00360     node_type *midpoint = 0;
00361     typename container_iterator::difference_type diff;
00362 
00363     diff = end - begin;
00364 
00365     if(diff > 1) {
00366       discriminator_type discriminator = comparator.discriminator;
00367       container_iterator nth = begin + (diff >> 1);
00368       container_iterator nthprev = nth - 1;
00369 
00370       //std::nth_element(begin, nth, end, comparator);
00371       std::stable_sort(begin, end, comparator);
00372 
00373       // Ties go in the right subtree.
00374       while(nth > begin &&
00375             (*nth)->point()[discriminator] == 
00376             (*nthprev)->point()[discriminator])
00377         {
00378           --nth;
00379           --nthprev;
00380         }
00381 
00382       midpoint = *nth;
00383       midpoint->discriminator = discriminator;
00384 
00385       if(++discriminator >= traits::dimensions)
00386         discriminator = 0;
00387 
00388       comparator.discriminator = discriminator;
00389 
00390       // Left subtree
00391       midpoint->child_low = optimize(begin, nth, comparator);
00392 
00393       comparator.discriminator = discriminator;
00394 
00395       // Right subtree
00396       midpoint->child_high = optimize(nth + 1, end, comparator);
00397     } else if(diff == 1) {
00398       midpoint = *begin;
00399       midpoint->discriminator = comparator.discriminator;
00400       midpoint->child_low = 0;
00401       midpoint->child_high = 0;
00402     }
00403 
00404     return midpoint;
00405   }
00406 
00407   template<class container>
00408   static void fill_container(container & c, node_type * const node) {
00409     if(node == 0)
00410       return;
00411     c.push_back(node);
00412     fill_container(c, node->child_low);
00413     fill_container(c, node->child_high);
00414   }
00415 
00416 public:
00417 
00421   explicit kd_tree() :
00422     _root(0), _size(0), _end_iterator(), _const_end_iterator()
00423   { }
00424 
00430   kd_tree(const kd_tree & tree) :
00431     _root(0), _size(0), _end_iterator(), _const_end_iterator()
00432   {
00433     for(const_iterator p = tree.begin(); !p.end_of_range(); ++p)
00434       insert(p->first, p->second);
00435   }
00436 
00440   virtual ~kd_tree() { delete _root; }
00441 
00445   void clear() {
00446     delete _root;
00447     _root = 0;
00448     _size = 0;
00449   }
00450 
00458   kd_tree & operator=(const kd_tree & tree) {
00459     clear();
00460     for(const_iterator p = tree.begin(); !p.end_of_range(); ++p)
00461       insert(p->first, p->second);
00462     return *this;
00463   }
00464 
00470   const size_type size() const {
00471     return _size;
00472   }
00473 
00479   const size_type max_size() const {
00480     return std::numeric_limits<size_type>::max();
00481   }
00482 
00490   bool empty() const {
00491     return (_root == 0);
00492   }
00493 
00501   iterator begin() {
00502     return iterator(default_region_type(traits::lower_bound, traits::upper_bound), _root);
00503   }
00504 
00512   const_iterator begin() const {
00513     return const_iterator(default_region_type(traits::lower_bound, traits::upper_bound), _root);
00514   }
00515 
00521   iterator & end() {
00522     return _end_iterator;
00523   }
00524 
00530   const const_iterator & end() const {
00531     return _const_end_iterator;
00532   }
00533 
00546   bool insert(const key_type & point, const mapped_type & value,
00547               mapped_type * const replaced = 0)
00548   {
00549     node_type *parent;
00550     node_type *node = get_node(point, &parent);
00551 
00552     return add(point, value, &node, parent, replaced);
00553   }
00554 
00566   std::pair<iterator,bool> insert(const value_type & mapping) {
00567     // Ideally, we'd do this all in one step, but that will have
00568     // to wait until we optimize the way we handle iterators.
00569     mapped_type existing;
00570     const bool replaced  = insert(mapping.first, mapping.second, &existing);
00571     const iterator value = find(mapping.first);
00572 
00573     if(replaced)
00574       value._node->value() = existing;
00575 
00576     return std::pair<iterator,bool>(value,!replaced);
00577   }
00578 
00579 
00589   mapped_type & operator[](const key_type & point) {
00590     node_type *parent;
00591     node_type *node = get_node(point, &parent);
00592 
00593     if(node == 0)
00594       add(point, mapped_type(), &node, parent);
00595 
00596     return node->value();
00597   }
00598 
00609   bool remove(const key_type & point, mapped_type * const erased = 0) {
00610     node_type *parent;
00611     node_type * const node = get_node(point, &parent);
00612 
00613     if(node == 0)
00614       return false;
00615 
00616     if(erased != 0)
00617       *erased = node->value();
00618 
00619     return remove(node, parent);
00620   }
00621 
00628   size_type erase(const key_type & point) {
00629     return remove(point);
00630   }
00631 
00642   iterator erase(iterator pos) {
00643     if(pos.end_of_range())
00644       return _end_iterator;
00645 
00646     node_type *parent;
00647     node_type * const node = get_node(pos->first, &parent);
00648     
00649     if(node == 0)
00650       return _end_iterator;
00651 
00652     typename iterator::stack_type & stack = pos._stack;
00653 
00654     // Pop any children.  Tree at parent and above is unchanged.
00655     // Low child is pushed last so check it first.
00656     if(!stack.empty() && node->child_low == stack.top()) {
00657       stack.pop();
00658     }
00659     if(!stack.empty() && node->child_high == stack.top()) {
00660       stack.pop();
00661     }
00662 
00663     const bool low_child = (parent != 0 && parent->child_low == node);
00664 
00665     if(remove(node, parent)) {
00666       if(parent != 0) {
00667         if(low_child && parent->child_low != 0) {
00668           stack.push(parent->child_low);
00669         } else if(!low_child && parent->child_high != 0) {
00670           stack.push(parent->child_high);
00671         }
00672         pos.advance();
00673         return pos;
00674       } else if(_root != 0) {
00675         stack.push(_root);
00676         pos.advance();
00677         return pos;
00678       }
00679     }
00680 
00681     return _end_iterator;
00682   }
00683 
00697   iterator begin(const key_type & lower, const key_type & upper) {
00698     return iterator(default_region_type(lower, upper), _root);
00699   }
00700 
00714   const_iterator begin(const key_type & lower, const key_type & upper) const {
00715     return const_iterator(default_region_type(lower, upper), _root);
00716   }
00717 
00718   // TODO: Document these.  Implement circle_region and sphere_region
00719   // and write unit tests.  Move kd_tree_range_search_iterator out of detail.
00720   template<typename Region>
00721   detail::kd_tree_range_search_iterator<traits, Region>
00722   begin(const Region & region) {
00723     return
00724       detail::kd_tree_range_search_iterator<traits, Region>(region, _root);
00725   }
00726 
00727   template<typename Region>
00728   detail::kd_tree_range_search_iterator<traits, Region> end() {
00729     return detail::kd_tree_range_search_iterator<traits, Region>();
00730   }
00731 
00732   template<typename Region>
00733   detail::kd_tree_range_search_iterator<const_traits, Region>
00734   begin(const Region & region) const {
00735     return
00736       detail::kd_tree_range_search_iterator<const_traits, Region>(region, _root);
00737   }
00738 
00739   template<typename Region>
00740   detail::kd_tree_range_search_iterator<const_traits, Region> end() const {
00741     return detail::kd_tree_range_search_iterator<const_traits, Region>();
00742   }
00743 
00757   bool get(const key_type & point, mapped_type * const value = 0) const {
00758     const node_type * const node = get_node(point);
00759 
00760     if(node == 0)
00761       return false;
00762     else if(value != 0)
00763       *value = node->value();
00764 
00765     return true;
00766   }
00767 
00778   iterator find(const key_type & point) {
00779     return iterator(default_region_type(point, traits::upper_bound), _root, true);
00780   }
00781 
00792   const_iterator find(const key_type & point) const {
00793     return const_iterator(default_region_type(point, traits::upper_bound),
00794                           _root, true);
00795   }
00796 
00804   void optimize() {
00805     if(empty())
00806       return;
00807 
00808     typedef std::vector<node_type*> container;
00809     container nodes;
00810 
00811     nodes.reserve(size());
00812     fill_container(nodes, _root);
00813 
00814     _root = optimize(nodes.begin(), nodes.end(), node_comparator());
00815   }
00816 
00829   friend bool operator==(const kd_tree & tree1, const kd_tree & tree2) {
00830     if(tree1.size() != tree2.size())
00831       return false;
00832 
00833     mapped_type value;
00834 
00835     for(const_iterator p = tree2.begin(); !p.end_of_range(); ++p) {
00836       if(!tree1.get(p->first, &value) || value != p->second)
00837         return false;
00838     }
00839 
00840     return true;
00841   }
00842 
00843   // Experimental functions whose API may change or may become standalone
00844   // functions or functor classes.
00845 
00870   iterator find_nearest_neighbor(const key_type & point,
00871                                  const bool omit_query_point = true)
00872   {
00873     const detail::kd_tree_nearest_neighbor<traits, double> nn;
00874     return iterator(nn.find(_root, point, omit_query_point));
00875   }
00876 
00877 #ifdef LIBSSRCKDTREE_HAVE_BOOST
00878   typedef
00879   typename detail::kd_tree_nearest_neighbors<traits, double>::iterator
00880   knn_iterator;
00881 
00900   std::pair<knn_iterator, knn_iterator>
00901   find_nearest_neighbors(const key_type & point,
00902                          const unsigned int num_neighbors,
00903                          const bool omit_query_point = true)
00904   {
00905     const detail::kd_tree_nearest_neighbors<traits, double> knn;
00906     return knn.find(_root, point, num_neighbors, omit_query_point);
00907   }
00908 #endif
00909 
00910 };
00911 
00912 __END_NS_SSRC_SPATIAL
00913 
00914 #endif

Savarese Software Research Corporation
Copyright © 2003-2005 Daniel F. Savarese.
Copyright © 2006-2009 Savarese Software Research Corporation.