00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00025 #ifndef __SAVA_SPATIAL_KDTREE_H
00026 #define __SAVA_SPATIAL_KDTREE_H
00027
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include <libsava/spatial/detail/KDTree.h>
#include <libsava/spatial/Point.h>
00034
00035 __BEGIN_PACKAGE_SAVA_SPATIAL
00036
/**
 * Traits class that re-exports the member types of a KDTree-like
 * container so that helper templates (nodes, iterators) can be written
 * against a single traits parameter instead of the full tree type.
 *
 * @tparam Tree The tree type whose nested typedefs are forwarded.
 */
template<typename Tree>
struct KDTreeTraits {
  typedef typename Tree::key_type key_type;
  typedef typename Tree::mapped_type mapped_type;
  typedef typename Tree::value_type value_type;
  typedef typename Tree::pointer pointer;
  typedef typename Tree::const_pointer const_pointer;
  typedef typename Tree::reference reference;
  typedef typename Tree::const_reference const_reference;
  typedef typename Tree::discriminator_type discriminator_type;
  typedef typename Tree::node_type node_type;
  typedef typename node_type::child_type child_type;
  typedef typename Tree::iterator iterator;
  typedef typename Tree::const_iterator const_iterator;
  typedef typename Tree::point_traits point_traits;
  typedef Tree tree_type;

  /**
   * @return The number of dimensions reported by Tree::dimensions().
   *
   * Returned by plain value: a top-level const on a scalar return type is
   * ignored by the language and only produces compiler warnings.
   */
  static inline unsigned int dimensions() {
    return Tree::dimensions();
  }
};
00061
00062
00068 template<typename Tree>
00069 struct KDTreeConstTraits : public KDTreeTraits<Tree> {
00070 typedef typename Tree::const_pointer pointer;
00071 typedef typename Tree::const_reference reference;
00072 };
00073
00074
00075
00076
00089 template <typename Point,
00090 typename Value,
00091 unsigned int Dimensions = 2,
00092 typename Discriminator = unsigned char,
00093 typename Size = unsigned int>
00094 class KDTree {
00095
00096 public:
00097
00098 typedef KDTreeTraits<KDTree> traits;
00099 typedef KDTreeConstTraits<KDTree> const_traits;
00100 typedef Point key_type;
00101 typedef Value mapped_type;
00102 typedef std::pair<const key_type, mapped_type> value_type;
00103 typedef value_type* pointer;
00104 typedef pointer const const_pointer;
00105 typedef value_type& reference;
00106 typedef const reference const_reference;
00107 typedef Discriminator discriminator_type;
00108 typedef detail::KDTreeRangeSearchIterator<traits> iterator;
00109 typedef detail::KDTreeRangeSearchIterator<const_traits> const_iterator;
00110 typedef PointTraits<key_type> point_traits;
00111
00112 typedef Size size_type;
00113
00114 static inline const unsigned int dimensions() {
00115 return Dimensions;
00116 }
00117
00118 typedef detail::KDTreeNode<traits> node_type;
00119 typedef typename node_type::child_type child_type;
00120
00121 private:
00122
00123 struct NodeComparator {
00124 discriminator_type discriminator;
00125
00126 explicit NodeComparator() : discriminator(0) { }
00127
00128 bool operator()(const node_type* const & n1,
00129 const node_type* const & n2) const
00130 {
00131 return (n1->point()[discriminator] < n2->point()[discriminator]);
00132 }
00133 };
00134
00135 node_type *_root;
00136 size_type _size;
00137 iterator _endIterator;
00138 const_iterator _constEndIterator;
00139
00140 node_type * getNode(const key_type & point, node_type **parent = 0) const {
00141 discriminator_type discriminator;
00142 child_type child;
00143 node_type *node = _root, *last = 0;
00144
00145 while(node != 0) {
00146 discriminator = node->discriminator();
00147
00148 if(point[discriminator] > node->point()[discriminator])
00149 child = node_type::ChildHigh;
00150 else if(point[discriminator] < node->point()[discriminator])
00151 child = node_type::ChildLow;
00152 else if(node->point() == point) {
00153 if(parent != 0)
00154 *parent = last;
00155 return node;
00156 } else
00157 child = node_type::ChildHigh;
00158
00159 last = node;
00160 node = node->child(child);
00161 }
00162
00163 if(parent != 0)
00164 *parent = last;
00165
00166 return 0;
00167 }
00168
00169 node_type * getMinimumNode(node_type *node, node_type *p,
00170 const discriminator_type discriminator,
00171 node_type **parent)
00172 {
00173 node_type *result;
00174
00175 if(discriminator == node->discriminator()) {
00176 if(node->child(node_type::ChildLow) != 0)
00177 return
00178 getMinimumNode(node->child(node_type::ChildLow), node,
00179 discriminator, parent);
00180 else
00181 result = node;
00182 } else {
00183 node_type *nlow = 0, *nhigh = 0;
00184 node_type *plow, *phigh;
00185
00186 if(node->child(node_type::ChildLow) != 0)
00187 nlow =
00188 getMinimumNode(node->child(node_type::ChildLow), node,
00189 discriminator, &plow);
00190
00191 if(node->child(node_type::ChildHigh) != 0)
00192 nhigh =
00193 getMinimumNode(node->child(node_type::ChildHigh), node,
00194 discriminator, &phigh);
00195
00196 if(nlow != 0 && nhigh != 0) {
00197 if(nlow->point()[discriminator] < nhigh->point()[discriminator]) {
00198 result = nlow;
00199 *parent = plow;
00200 } else {
00201 result = nhigh;
00202 *parent = phigh;
00203 }
00204 } else if(nlow != 0) {
00205 result = nlow;
00206 *parent = plow;
00207 } else if(nhigh != 0) {
00208 result = nhigh;
00209 *parent = phigh;
00210 } else
00211 result = node;
00212 }
00213
00214 if(result == node)
00215 *parent = p;
00216 else if(node->point()[discriminator] < result->point()[discriminator]) {
00217 result = node;
00218 *parent = p;
00219 }
00220
00221 return result;
00222 }
00223
00224
00225 node_type * recursiveRemoveNode(node_type *node) {
00226 discriminator_type discriminator;
00227 node_type *newRoot, *parent;
00228
00229 if(node->child(node_type::ChildLow) == 0 &&
00230 node->child(node_type::ChildHigh) == 0)
00231 return 0;
00232 else
00233 discriminator = node->discriminator();
00234
00235 if(node->child(node_type::ChildHigh) == 0) {
00236 node->child(node_type::ChildHigh) = node->child(node_type::ChildLow);
00237 node->child(node_type::ChildLow) = 0;
00238 }
00239
00240 newRoot =
00241 getMinimumNode(node->child(node_type::ChildHigh), node,
00242 discriminator, &parent);
00243
00244 child_type child = (parent->child(node_type::ChildLow) == newRoot ?
00245 node_type::ChildLow : node_type::ChildHigh);
00246 parent->child(child) = recursiveRemoveNode(newRoot);
00247
00248 newRoot->child(node_type::ChildLow) = node->child(node_type::ChildLow);
00249 newRoot->child(node_type::ChildHigh) = node->child(node_type::ChildHigh);
00250 newRoot->discriminator() = node->discriminator();
00251
00252 return newRoot;
00253 }
00254
00255
00256 bool add(const key_type & point, const mapped_type & value,
00257 node_type **node, node_type *parent,
00258 mapped_type *replaced = 0)
00259 {
00260 if(parent == 0) {
00261 if(_root != 0)
00262 *node = _root;
00263 else {
00264 _root = *node = new node_type(0, point, value);
00265 ++_size;
00266 return false;
00267 }
00268 } else if(*node == 0) {
00269 discriminator_type discriminator;
00270 child_type child;
00271
00272 discriminator = parent->discriminator();
00273 child = (point[discriminator] >= parent->point()[discriminator] ?
00274 node_type::ChildHigh : node_type::ChildLow);
00275
00276 if(++discriminator >= dimensions())
00277 discriminator = 0;
00278
00279 parent->child(child) = *node =
00280 new node_type(discriminator, point, value);
00281
00282 ++_size;
00283 return false;
00284 }
00285
00286 if(replaced != 0)
00287 *replaced = (*node)->value();
00288
00289 (*node)->value() = value;
00290
00291 return true;
00292 }
00293
00294 template<template<typename> class Container>
00295 static inline
00296 node_type * optimize(typename Container<node_type*>::iterator begin,
00297 typename Container<node_type*>::iterator end,
00298 NodeComparator & comparator)
00299 {
00300 node_type *midpoint = 0;
00301 typename Container<node_type*>::iterator::difference_type diff;
00302
00303 diff = end - begin;
00304
00305 if(diff > 1) {
00306 discriminator_type discriminator = comparator.discriminator;
00307 typename Container<node_type*>::iterator nth = begin + (diff >> 1);
00308 typename Container<node_type*>::iterator nthprev = nth - 1;
00309
00310
00311 stable_sort(begin, end, comparator);
00312
00313
00314 while(nth > begin &&
00315 (*nth)->point()[discriminator] ==
00316 (*nthprev)->point()[discriminator])
00317 {
00318 --nth;
00319 --nthprev;
00320 }
00321
00322 midpoint = *nth;
00323 midpoint->discriminator() = discriminator;
00324
00325 if(++discriminator >= dimensions())
00326 discriminator = 0;
00327
00328 comparator.discriminator = discriminator;
00329
00330
00331 midpoint->child(node_type::ChildLow) =
00332 optimize<Container>(begin, nth, comparator);
00333
00334 comparator.discriminator = discriminator;
00335
00336
00337 midpoint->child(node_type::ChildHigh) =
00338 optimize<Container>(nth + 1, end, comparator);
00339 } else if(diff == 1) {
00340 midpoint = *begin;
00341 midpoint->discriminator() = comparator.discriminator;
00342 midpoint->child(node_type::ChildLow) = 0;
00343 midpoint->child(node_type::ChildHigh) = 0;
00344 }
00345
00346 return midpoint;
00347 }
00348
00349 template<template<typename> class Container>
00350 static inline void fillContainer(Container<node_type*> & c, node_type *node)
00351 {
00352 if(node == 0)
00353 return;
00354 c.push_back(node);
00355 fillContainer(c, node->child(node_type::ChildLow));
00356 fillContainer(c, node->child(node_type::ChildHigh));
00357 }
00358
00359 static inline
00360 void initPoint(key_type & point,
00361 const typename point_traits::coordinate_type & value)
00362 {
00363 for(unsigned int i=0; i < point_traits::dimensions(); ++i)
00364 point[i] = value;
00365 }
00366
00367 static inline const key_type upperBound() {
00368 key_type bound;
00369 initPoint(bound, point_traits::max_coordinate());
00370 return bound;
00371 }
00372
00373 static inline const key_type lowerBound() {
00374 key_type bound;
00375 initPoint(bound, point_traits::min_coordinate());
00376 return bound;
00377 }
00378
00379 public:
00380
00381
00382
00386 explicit KDTree() : _root(0), _size(0), _endIterator(), _constEndIterator()
00387 { }
00388
00394 KDTree(const KDTree & tree) :
00395 _root(0), _size(0), _endIterator(), _constEndIterator()
00396 {
00397 for(const_iterator p = tree.begin(); !p.endOfRange(); ++p)
00398 insert(p->first, p->second);
00399 }
00400
00404 virtual ~KDTree() { delete _root; }
00405
00409 void clear() {
00410 delete _root;
00411 _root = 0;
00412 _size = 0;
00413 }
00414
00420 const size_type size() const {
00421 return _size;
00422 }
00423
00429 const size_type max_size() const {
00430 return std::numeric_limits<size_type>::max();
00431 }
00432
00440 bool empty() const {
00441 return (_root == 0);
00442 }
00443
00451 iterator begin() {
00452 return iterator(lowerBound(), upperBound(), _root);
00453 }
00454
00462 const_iterator begin() const {
00463 return const_iterator(lowerBound(), upperBound(), _root);
00464 }
00465
00471 iterator & end() {
00472 return _endIterator;
00473 }
00474
00480 const const_iterator & end() const {
00481 return _constEndIterator;
00482 }
00483
00496 bool insert(const key_type & point, const mapped_type & value,
00497 mapped_type *replaced = 0)
00498 {
00499 node_type *parent;
00500 node_type *node = getNode(point, &parent);
00501
00502 return add(point, value, &node, parent, replaced);
00503 }
00504
00516 std::pair<iterator,bool> insert(const value_type & mapping) {
00517
00518
00519 bool replaced;
00520 mapped_type existing;
00521 iterator value;
00522
00523 replaced = insert(mapping.first, mapping.second, &existing);
00524 value = find(mapping.first);
00525
00526 if(replaced)
00527 value._node->value() = existing;
00528
00529 return std::pair<iterator,bool>(value,!replaced);
00530 }
00531
00532
00542 mapped_type & operator[](const key_type & point) {
00543 node_type *parent;
00544 node_type *node = getNode(point, &parent);
00545
00546 if(node == 0)
00547 add(point, mapped_type(), &node, parent);
00548
00549 return node->value();
00550 }
00551
00562 bool remove(const key_type & point, mapped_type *erased = 0) {
00563 node_type *parent;
00564 node_type *node = getNode(point, &parent);
00565 node_type *child;
00566
00567 if(node == 0)
00568 return false;
00569
00570 if(erased != 0)
00571 *erased = node->value();
00572
00573 child = node;
00574 node = recursiveRemoveNode(child);
00575
00576 if(parent == 0)
00577 _root = node;
00578 else if(child == parent->child(node_type::ChildLow))
00579 parent->child(node_type::ChildLow) = node;
00580 else
00581 parent->child(node_type::ChildHigh) = node;
00582
00583
00584 child->child(node_type::ChildLow) = 0;
00585 child->child(node_type::ChildHigh) = 0;
00586
00587 --_size;
00588 delete child;
00589
00590 return true;
00591 }
00592
00599 size_type erase(const key_type & point) {
00600 return remove(point);
00601 }
00602
00608 void erase(iterator pos) {
00609 remove(pos->first);
00610 }
00611
00625 iterator begin(const key_type & lower, const key_type & upper) {
00626 return iterator(lower, upper, _root);
00627 }
00628
00642 const_iterator begin(const key_type & lower, const key_type & upper) const {
00643 return const_iterator(lower, upper, _root);
00644 }
00645
00659 bool get(const key_type & point, mapped_type *value = 0) const {
00660 node_type *node = getNode(point);
00661
00662 if(node == 0)
00663 return false;
00664 else if(value != 0)
00665 *value = node->value();
00666
00667 return true;
00668 }
00669
00680 iterator find(const key_type & point) {
00681 return iterator(point, upperBound(), _root, true);
00682 }
00683
00694 const_iterator find(const key_type & point) const {
00695 return const_iterator(point, upperBound(), _root, true);
00696 }
00697
00698
00706 void optimize() {
00707 if(empty())
00708 return;
00709
00710 typedef std::vector<node_type*> container;
00711 container nodes;
00712
00713 nodes.reserve(size());
00714 fillContainer<std::vector>(nodes, _root);
00715
00716 NodeComparator comparator;
00717 _root =
00718 optimize<std::vector>(nodes.begin(), nodes.end(), comparator);
00719 }
00720
00721 };
00722
00723
00724 __END_PACKAGE_SAVA_SPATIAL
00725
00726 #endif
00727
Copyright © 2003-2005 Savarese Software Research and Daniel F. Savarese. All rights reserved.