kd_tree.h
Go to the documentation of this file.
00001 /* 00002 * Copyright 2003-2005 Daniel F. Savarese 00003 * Copyright 2006-2009 Savarese Software Research Corporation 00004 * 00005 * Licensed under the Apache License, Version 2.0 (the "License"); 00006 * you may not use this file except in compliance with the License. 00007 * You may obtain a copy of the License at 00008 * 00009 * https://www.savarese.com/software/ApacheLicense-2.0 00010 * 00011 * Unless required by applicable law or agreed to in writing, software 00012 * distributed under the License is distributed on an "AS IS" BASIS, 00013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00014 * See the License for the specific language governing permissions and 00015 * limitations under the License. 00016 */ 00017 00023 #ifndef __SSRC_SPATIAL_KDTREE_H 00024 #define __SSRC_SPATIAL_KDTREE_H 00025 00026 #include <ssrc/spatial/detail/kd_tree_range_search_iterator.h> 00027 #include <ssrc/spatial/detail/kd_tree_node.h> 00028 #include <ssrc/spatial/detail/kd_tree_nearest_neighbor.h> 00029 00030 #ifdef LIBSSRCKDTREE_HAVE_BOOST 00031 #include <ssrc/spatial/detail/kd_tree_nearest_neighbors.h> 00032 #endif 00033 00034 #include <ssrc/spatial/rectangle_region.h> 00035 00036 #include <algorithm> 00037 #include <utility> 00038 #include <vector> 00039 00040 __BEGIN_NS_SSRC_SPATIAL 00041 00045 template<typename Tree> 00046 struct kd_tree_traits { 00047 typedef typename Tree::key_type key_type; 00048 typedef typename Tree::mapped_type mapped_type; 00049 typedef typename Tree::value_type value_type; 00050 typedef typename Tree::pointer pointer; 00051 typedef typename Tree::const_pointer const_pointer; 00052 typedef typename Tree::reference reference; 00053 typedef typename Tree::const_reference const_reference; 00054 typedef typename Tree::discriminator_type discriminator_type; 00055 typedef typename Tree::node_type node_type; 00056 typedef typename Tree::iterator iterator; 00057 typedef typename Tree::const_iterator const_iterator; 00058 typedef typename Tree::size_type size_type; 00059 typedef typename key_type::value_type coordinate_type; 00060 typedef Tree tree_type; 00061 00067 static const coordinate_type max_coordinate() { 00068 return detail::coordinate_limits<coordinate_type>::highest(); 00069 } 00070 00076 static const coordinate_type min_coordinate() { 00077 return detail::coordinate_limits<coordinate_type>::lowest(); 00078 } 00079 00081 static const key_type upper_bound; 00082 00084 static const key_type lower_bound; 00085 00087 static const unsigned int dimensions = NS_TR1::tuple_size<key_type>::value; 00088 00089 private: 00090 static key_type init_point(const coordinate_type & value) { 00091 key_type point; 00092 for(unsigned int i=0; i < dimensions; ++i) 00093 point[i] = value; 00094 return point; 00095 } 00096 00097 static const key_type _upper_bound() { 00098 return init_point(max_coordinate()); 00099 } 00100 00101 static const key_type _lower_bound() { 00102 return init_point(min_coordinate()); 00103 } 00104 }; 00105 00106 template<class Tree> 00107 typename kd_tree_traits<Tree>::key_type const 00108 kd_tree_traits<Tree>::upper_bound(kd_tree_traits<Tree>::_upper_bound()); 00109 00110 template<class Tree> 00111 typename kd_tree_traits<Tree>::key_type const 00112 kd_tree_traits<Tree>::lower_bound(kd_tree_traits<Tree>::_lower_bound()); 00113 00117 template<typename Tree> 00118 struct kd_tree_const_traits : public kd_tree_traits<Tree> { 00119 typedef typename Tree::const_pointer pointer; 00120 typedef typename Tree::const_reference reference; 00121 }; 00122 00123 00124 // Note: we store the discriminator in each node to avoid modulo division, 00125 // trading space for time. 00136 template<typename Point, 00137 typename Value, 00138 typename Discriminator = unsigned char, 00139 typename Size = unsigned int> 00140 class kd_tree { 00141 public: 00142 00143 typedef kd_tree_traits<kd_tree> traits; 00144 typedef kd_tree_const_traits<kd_tree> const_traits; 00145 typedef Point key_type; 00146 typedef Value mapped_type; 00147 typedef std::pair<const key_type, mapped_type> value_type; 00148 typedef value_type* pointer; 00149 typedef pointer const const_pointer; 00150 typedef value_type& reference; 00151 typedef const reference const_reference; 00152 typedef Discriminator discriminator_type; 00153 typedef rectangle_region<key_type> default_region_type; 00154 // Is this really what we want--two distinct types as 00155 // opposed to iterator and const iterator? 00156 typedef 00157 detail::kd_tree_range_search_iterator<traits, default_region_type> iterator; 00158 typedef 00159 detail::kd_tree_range_search_iterator<const_traits, default_region_type> 00160 const_iterator; 00161 00162 typedef Size size_type; 00163 00164 typedef detail::kd_tree_node<traits> node_type; 00165 00166 private: 00167 00168 struct node_comparator { 00169 mutable discriminator_type discriminator; 00170 00171 explicit node_comparator() : discriminator(0) { } 00172 00173 bool operator()(const node_type* const & n1, 00174 const node_type* const & n2) const 00175 { 00176 return (n1->point()[discriminator] < n2->point()[discriminator]); 00177 } 00178 }; 00179 00180 node_type *_root; 00181 size_type _size; 00182 iterator _end_iterator; 00183 const_iterator _const_end_iterator; 00184 00185 node_type * 00186 get_node(const key_type & point, node_type ** const parent = 0) const { 00187 discriminator_type discriminator; 00188 node_type *node = _root, *last = 0; 00189 00190 while(node != 0) { 00191 discriminator = node->discriminator; 00192 00193 if(point[discriminator] > node->point()[discriminator]) { 00194 last = node; 00195 node = node->child_high; 00196 } else if(point[discriminator] < node->point()[discriminator]) { 00197 last = node; 00198 node = node->child_low; 00199 } else if(node->point() == point) { 00200 if(parent != 0) 00201 *parent = last; 00202 return node; 00203 } else { 00204 last = node; 00205 node = node->child_high; 00206 } 00207 } 00208 00209 if(parent != 0) 00210 *parent = last; 00211 00212 return 0; 00213 } 00214 00215 node_type * get_minimum_node(node_type * const node, node_type * const p, 00216 const discriminator_type discriminator, 00217 node_type ** const parent) 00218 { 00219 node_type *result; 00220 00221 if(discriminator == node->discriminator) { 00222 if(node->child_low != 0) 00223 return 00224 get_minimum_node(node->child_low, node, 00225 discriminator, parent); 00226 else 00227 result = node; 00228 } else { 00229 node_type *nlow = 0, *nhigh = 0; 00230 node_type *plow, *phigh; 00231 00232 if(node->child_low != 0) 00233 nlow = get_minimum_node(node->child_low, node, 00234 discriminator, &plow); 00235 00236 if(node->child_high != 0) 00237 nhigh = get_minimum_node(node->child_high, node, 00238 discriminator, &phigh); 00239 00240 if(nlow != 0 && nhigh != 0) { 00241 if(nlow->point()[discriminator] < nhigh->point()[discriminator]) { 00242 result = nlow; 00243 *parent = plow; 00244 } else { 00245 result = nhigh; 00246 *parent = phigh; 00247 } 00248 } else if(nlow != 0) { 00249 result = nlow; 00250 *parent = plow; 00251 } else if(nhigh != 0) { 00252 result = nhigh; 00253 *parent = phigh; 00254 } else 00255 result = node; 00256 } 00257 00258 if(result == node) 00259 *parent = p; 00260 else if(node->point()[discriminator] < result->point()[discriminator]) { 00261 result = node; 00262 *parent = p; 00263 } 00264 00265 return result; 00266 } 00267 00268 node_type * recursive_remove_node(node_type * const node) { 00269 discriminator_type discriminator; 00270 node_type *new_root, *parent; 00271 00272 if(node->child_low == 0 && 00273 node->child_high == 0) 00274 return 0; 00275 else 00276 discriminator = node->discriminator; 00277 00278 if(node->child_high == 0) { 00279 node->child_high = node->child_low; 00280 node->child_low = 0; 00281 } 00282 00283 new_root = get_minimum_node(node->child_high, node, 00284 discriminator, &parent); 00285 00286 if(parent->child_low == new_root) 00287 parent->child_low = recursive_remove_node(new_root); 00288 else 00289 parent->child_high = recursive_remove_node(new_root); 00290 00291 new_root->child_low = node->child_low; 00292 new_root->child_high = node->child_high; 00293 new_root->discriminator = node->discriminator; 00294 00295 return new_root; 00296 } 00297 00298 // Splitting up remove in this way allows us to implement 00299 // iterator erase(iterator) properly. 00300 bool remove(node_type * const node, node_type * const parent) { 00301 node_type * const new_root = recursive_remove_node(node); 00302 00303 if(parent == 0) 00304 _root = new_root; 00305 else if(node == parent->child_low) 00306 parent->child_low = new_root; 00307 else 00308 parent->child_high = new_root; 00309 00310 // Must zero children so they are not deleted by ~node_type() 00311 node->child_low = 0; 00312 node->child_high = 0; 00313 00314 --_size; 00315 delete node; 00316 00317 return true; 00318 } 00319 00320 bool add(const key_type & point, const mapped_type & value, 00321 node_type ** const node, node_type *parent, 00322 mapped_type * const replaced = 0) 00323 { 00324 if(parent == 0) { 00325 if(_root != 0) 00326 *node = _root; 00327 else { 00328 _root = *node = new node_type(0, point, value); 00329 ++_size; 00330 return false; 00331 } 00332 } else if(*node == 0) { 00333 discriminator_type discriminator = parent->discriminator; 00334 node_type* & child = 00335 (point[discriminator] >= parent->point()[discriminator] ? 00336 parent->child_high : parent->child_low); 00337 00338 if(++discriminator >= traits::dimensions) 00339 discriminator = 0; 00340 00341 child = *node = new node_type(discriminator, point, value); 00342 00343 ++_size; 00344 return false; 00345 } 00346 00347 if(replaced != 0) 00348 *replaced = (*node)->value(); 00349 00350 (*node)->value() = value; 00351 00352 return true; 00353 } 00354 00355 template<typename container_iterator> 00356 static node_type * optimize(const container_iterator & begin, 00357 const container_iterator & end, 00358 const node_comparator & comparator) 00359 { 00360 node_type *midpoint = 0; 00361 typename container_iterator::difference_type diff; 00362 00363 diff = end - begin; 00364 00365 if(diff > 1) { 00366 discriminator_type discriminator = comparator.discriminator; 00367 container_iterator nth = begin + (diff >> 1); 00368 container_iterator nthprev = nth - 1; 00369 00370 //std::nth_element(begin, nth, end, comparator); 00371 std::stable_sort(begin, end, comparator); 00372 00373 // Ties go in the right subtree. 00374 while(nth > begin && 00375 (*nth)->point()[discriminator] == 00376 (*nthprev)->point()[discriminator]) 00377 { 00378 --nth; 00379 --nthprev; 00380 } 00381 00382 midpoint = *nth; 00383 midpoint->discriminator = discriminator; 00384 00385 if(++discriminator >= traits::dimensions) 00386 discriminator = 0; 00387 00388 comparator.discriminator = discriminator; 00389 00390 // Left subtree 00391 midpoint->child_low = optimize(begin, nth, comparator); 00392 00393 comparator.discriminator = discriminator; 00394 00395 // Right subtree 00396 midpoint->child_high = optimize(nth + 1, end, comparator); 00397 } else if(diff == 1) { 00398 midpoint = *begin; 00399 midpoint->discriminator = comparator.discriminator; 00400 midpoint->child_low = 0; 00401 midpoint->child_high = 0; 00402 } 00403 00404 return midpoint; 00405 } 00406 00407 template<class container> 00408 static void fill_container(container & c, node_type * const node) { 00409 if(node == 0) 00410 return; 00411 c.push_back(node); 00412 fill_container(c, node->child_low); 00413 fill_container(c, node->child_high); 00414 } 00415 00416 public: 00417 00421 explicit kd_tree() : 00422 _root(0), _size(0), _end_iterator(), _const_end_iterator() 00423 { } 00424 00430 kd_tree(const kd_tree & tree) : 00431 _root(0), _size(0), _end_iterator(), _const_end_iterator() 00432 { 00433 for(const_iterator p = tree.begin(); !p.end_of_range(); ++p) 00434 insert(p->first, p->second); 00435 } 00436 00440 virtual ~kd_tree() { delete _root; } 00441 00445 void clear() { 00446 delete _root; 00447 _root = 0; 00448 _size = 0; 00449 } 00450 00458 kd_tree & operator=(const kd_tree & tree) { 00459 clear(); 00460 for(const_iterator p = tree.begin(); !p.end_of_range(); ++p) 00461 insert(p->first, p->second); 00462 return *this; 00463 } 00464 00470 const size_type size() const { 00471 return _size; 00472 } 00473 00479 const size_type max_size() const { 00480 return std::numeric_limits<size_type>::max(); 00481 } 00482 00490 bool empty() const { 00491 return (_root == 0); 00492 } 00493 00501 iterator begin() { 00502 return iterator(default_region_type(traits::lower_bound, traits::upper_bound), _root); 00503 } 00504 00512 const_iterator begin() const { 00513 return const_iterator(default_region_type(traits::lower_bound, traits::upper_bound), _root); 00514 } 00515 00521 iterator & end() { 00522 return _end_iterator; 00523 } 00524 00530 const const_iterator & end() const { 00531 return _const_end_iterator; 00532 } 00533 00546 bool insert(const key_type & point, const mapped_type & value, 00547 mapped_type * const replaced = 0) 00548 { 00549 node_type *parent; 00550 node_type *node = get_node(point, &parent); 00551 00552 return add(point, value, &node, parent, replaced); 00553 } 00554 00566 std::pair<iterator,bool> insert(const value_type & mapping) { 00567 // Ideally, we'd do this all in one step, but that will have 00568 // to wait until we optimize the way we handle iterators. 00569 mapped_type existing; 00570 const bool replaced = insert(mapping.first, mapping.second, &existing); 00571 const iterator value = find(mapping.first); 00572 00573 if(replaced) 00574 value._node->value() = existing; 00575 00576 return std::pair<iterator,bool>(value,!replaced); 00577 } 00578 00579 00589 mapped_type & operator[](const key_type & point) { 00590 node_type *parent; 00591 node_type *node = get_node(point, &parent); 00592 00593 if(node == 0) 00594 add(point, mapped_type(), &node, parent); 00595 00596 return node->value(); 00597 } 00598 00609 bool remove(const key_type & point, mapped_type * const erased = 0) { 00610 node_type *parent; 00611 node_type * const node = get_node(point, &parent); 00612 00613 if(node == 0) 00614 return false; 00615 00616 if(erased != 0) 00617 *erased = node->value(); 00618 00619 return remove(node, parent); 00620 } 00621 00628 size_type erase(const key_type & point) { 00629 return remove(point); 00630 } 00631 00642 iterator erase(iterator pos) { 00643 if(pos.end_of_range()) 00644 return _end_iterator; 00645 00646 node_type *parent; 00647 node_type * const node = get_node(pos->first, &parent); 00648 00649 if(node == 0) 00650 return _end_iterator; 00651 00652 typename iterator::stack_type & stack = pos._stack; 00653 00654 // Pop any children. Tree at parent and above is unchanged. 00655 // Low child is pushed last so check it first. 00656 if(!stack.empty() && node->child_low == stack.top()) { 00657 stack.pop(); 00658 } 00659 if(!stack.empty() && node->child_high == stack.top()) { 00660 stack.pop(); 00661 } 00662 00663 const bool low_child = (parent != 0 && parent->child_low == node); 00664 00665 if(remove(node, parent)) { 00666 if(parent != 0) { 00667 if(low_child && parent->child_low != 0) { 00668 stack.push(parent->child_low); 00669 } else if(!low_child && parent->child_high != 0) { 00670 stack.push(parent->child_high); 00671 } 00672 pos.advance(); 00673 return pos; 00674 } else if(_root != 0) { 00675 stack.push(_root); 00676 pos.advance(); 00677 return pos; 00678 } 00679 } 00680 00681 return _end_iterator; 00682 } 00683 00697 iterator begin(const key_type & lower, const key_type & upper) { 00698 return iterator(default_region_type(lower, upper), _root); 00699 } 00700 00714 const_iterator begin(const key_type & lower, const key_type & upper) const { 00715 return const_iterator(default_region_type(lower, upper), _root); 00716 } 00717 00718 // TODO: Document these. Implement circle_region and sphere_region 00719 // and write unit tests. Move kd_tree_range_search_iterator out of detail. 00720 template<typename Region> 00721 detail::kd_tree_range_search_iterator<traits, Region> 00722 begin(const Region & region) { 00723 return 00724 detail::kd_tree_range_search_iterator<traits, Region>(region, _root); 00725 } 00726 00727 template<typename Region> 00728 detail::kd_tree_range_search_iterator<traits, Region> end() { 00729 return detail::kd_tree_range_search_iterator<traits, Region>(); 00730 } 00731 00732 template<typename Region> 00733 detail::kd_tree_range_search_iterator<const_traits, Region> 00734 begin(const Region & region) const { 00735 return 00736 detail::kd_tree_range_search_iterator<const_traits, Region>(region, _root); 00737 } 00738 00739 template<typename Region> 00740 detail::kd_tree_range_search_iterator<const_traits, Region> end() const { 00741 return detail::kd_tree_range_search_iterator<const_traits, Region>(); 00742 } 00743 00757 bool get(const key_type & point, mapped_type * const value = 0) const { 00758 const node_type * const node = get_node(point); 00759 00760 if(node == 0) 00761 return false; 00762 else if(value != 0) 00763 *value = node->value(); 00764 00765 return true; 00766 } 00767 00778 iterator find(const key_type & point) { 00779 return iterator(default_region_type(point, traits::upper_bound), _root, true); 00780 } 00781 00792 const_iterator find(const key_type & point) const { 00793 return const_iterator(default_region_type(point, traits::upper_bound), 00794 _root, true); 00795 } 00796 00804 void optimize() { 00805 if(empty()) 00806 return; 00807 00808 typedef std::vector<node_type*> container; 00809 container nodes; 00810 00811 nodes.reserve(size()); 00812 fill_container(nodes, _root); 00813 00814 _root = optimize(nodes.begin(), nodes.end(), node_comparator()); 00815 } 00816 00829 friend bool operator==(const kd_tree & tree1, const kd_tree & tree2) { 00830 if(tree1.size() != tree2.size()) 00831 return false; 00832 00833 mapped_type value; 00834 00835 for(const_iterator p = tree2.begin(); !p.end_of_range(); ++p) { 00836 if(!tree1.get(p->first, &value) || value != p->second) 00837 return false; 00838 } 00839 00840 return true; 00841 } 00842 00843 // Experimental functions whose API may change or may become standalone 00844 // functions or functor classes. 00845 00870 iterator find_nearest_neighbor(const key_type & point, 00871 const bool omit_query_point = true) 00872 { 00873 const detail::kd_tree_nearest_neighbor<traits, double> nn; 00874 return iterator(nn.find(_root, point, omit_query_point)); 00875 } 00876 00877 #ifdef LIBSSRCKDTREE_HAVE_BOOST 00878 typedef 00879 typename detail::kd_tree_nearest_neighbors<traits, double>::iterator 00880 knn_iterator; 00881 00900 std::pair<knn_iterator, knn_iterator> 00901 find_nearest_neighbors(const key_type & point, 00902 const unsigned int num_neighbors, 00903 const bool omit_query_point = true) 00904 { 00905 const detail::kd_tree_nearest_neighbors<traits, double> knn; 00906 return knn.find(_root, point, num_neighbors, omit_query_point); 00907 } 00908 #endif 00909 00910 }; 00911 00912 __END_NS_SSRC_SPATIAL 00913 00914 #endif