// KDTree.cc
  // --------------------------------------------------------------------------
  // A straightforward, but probably sub-optimal KD-tree implementation that's
  // probably good enough for most things (currently it's a 2D-tree)
  //
  // - constructs from n points in O(n lg^2 n) time
  // - handles nearest-neighbor query in O(lg n) if points are well distributed
  // - worst case for nearest-neighbor may be linear in pathological case
  //
  // Sonny Chan, Stanford University, April 2009
  // --------------------------------------------------------------------------
  #include <algorithm>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <vector>
  using namespace std;
  // Scalar type used for every coordinate and squared distance in this file,
  // together with the largest value that type can represent. The maximum is
  // used as an "infinity" marker (e.g. for the bounds of an empty box).
  typedef long long ntype;
  const ntype sentry = (std::numeric_limits<ntype>::max)();
  19. // point structure for 2D-tree, can be extended to 3D
  20. struct point {
  21. ntype x, y;
  22. point(ntype xx = 0, ntype yy = 0) : x(xx), y(yy) {}
  23. };
  24. bool operator==(const point &a, const point &b)
  25. {
  26. return a.x == b.x && a.y == b.y;
  27. }
  28. // sorts points on x-coordinate
  29. bool on_x(const point &a, const point &b)
  30. {
  31. return a.x < b.x;
  32. }
  33. // sorts points on y-coordinate
  34. bool on_y(const point &a, const point &b)
  35. {
  36. return a.y < b.y;
  37. }
  38. // squared distance between points
  39. ntype pdist2(const point &a, const point &b)
  40. {
  41. ntype dx = a.x-b.x, dy = a.y-b.y;
  42. return dx*dx + dy*dy;
  43. }
  44. // bounding box for a set of points
  45. struct bbox
  46. {
  47. ntype x0, x1, y0, y1;
  48. bbox() : x0(sentry), x1(-sentry), y0(sentry), y1(-sentry) {}
  49. // computes bounding box from a bunch of points
  50. void compute(const vector<point> &v) {
  51. for (int i = 0; i < v.size(); ++i) {
  52. x0 = min(x0, v[i].x); x1 = max(x1, v[i].x);
  53. y0 = min(y0, v[i].y); y1 = max(y1, v[i].y);
  54. }
  55. }
  56. // squared distance between a point and this bbox, 0 if inside
  57. ntype distance(const point &p) {
  58. if (p.x < x0) {
  59. if (p.y < y0) return pdist2(point(x0, y0), p);
  60. else if (p.y > y1) return pdist2(point(x0, y1), p);
  61. else return pdist2(point(x0, p.y), p);
  62. }
  63. else if (p.x > x1) {
  64. if (p.y < y0) return pdist2(point(x1, y0), p);
  65. else if (p.y > y1) return pdist2(point(x1, y1), p);
  66. else return pdist2(point(x1, p.y), p);
  67. }
  68. else {
  69. if (p.y < y0) return pdist2(point(p.x, y0), p);
  70. else if (p.y > y1) return pdist2(point(p.x, y1), p);
  71. else return 0;
  72. }
  73. }
  74. };
  75. // stores a single node of the kd-tree, either internal or leaf
  76. struct kdnode
  77. {
  78. bool leaf; // true if this is a leaf node (has one point)
  79. point pt; // the single point of this is a leaf
  80. bbox bound; // bounding box for set of points in children
  81. kdnode *first, *second; // two children of this kd-node
  82. kdnode() : leaf(false), first(0), second(0) {}
  83. ~kdnode() { if (first) delete first; if (second) delete second; }
  84. // intersect a point with this node (returns squared distance)
  85. ntype intersect(const point &p) {
  86. return bound.distance(p);
  87. }
  88. // recursively builds a kd-tree from a given cloud of points
  89. void construct(vector<point> &vp)
  90. {
  91. // compute bounding box for points at this node
  92. bound.compute(vp);
  93. // if we're down to one point, then we're a leaf node
  94. if (vp.size() == 1) {
  95. leaf = true;
  96. pt = vp[0];
  97. }
  98. else {
  99. // split on x if the bbox is wider than high (not best heuristic...)
  100. if (bound.x1-bound.x0 >= bound.y1-bound.y0)
  101. sort(vp.begin(), vp.end(), on_x);
  102. // otherwise split on y-coordinate
  103. else
  104. sort(vp.begin(), vp.end(), on_y);
  105. // divide by taking half the array for each child
  106. // (not best performance if many duplicates in the middle)
  107. int half = vp.size()/2;
  108. vector<point> vl(vp.begin(), vp.begin()+half);
  109. vector<point> vr(vp.begin()+half, vp.end());
  110. first = new kdnode(); first->construct(vl);
  111. second = new kdnode(); second->construct(vr);
  112. }
  113. }
  114. };
  115. // simple kd-tree class to hold the tree and handle queries
  116. struct kdtree
  117. {
  118. kdnode *root;
  119. // constructs a kd-tree from a points (copied here, as it sorts them)
  120. kdtree(const vector<point> &vp) {
  121. vector<point> v(vp.begin(), vp.end());
  122. root = new kdnode();
  123. root->construct(v);
  124. }
  125. ~kdtree() { delete root; }
  126. // recursive search method returns squared distance to nearest point
  127. ntype search(kdnode *node, const point &p)
  128. {
  129. if (node->leaf) {
  130. // commented special case tells a point not to find itself
  131. // if (p == node->pt) return sentry;
  132. // else
  133. return pdist2(p, node->pt);
  134. }
  135. ntype bfirst = node->first->intersect(p);
  136. ntype bsecond = node->second->intersect(p);
  137. // choose the side with the closest bounding box to search first
  138. // (note that the other side is also searched if needed)
  139. if (bfirst < bsecond) {
  140. ntype best = search(node->first, p);
  141. if (bsecond < best)
  142. best = min(best, search(node->second, p));
  143. return best;
  144. }
  145. else {
  146. ntype best = search(node->second, p);
  147. if (bfirst < best)
  148. best = min(best, search(node->first, p));
  149. return best;
  150. }
  151. }
  152. // squared distance to the nearest
  153. ntype nearest(const point &p) {
  154. return search(root, p);
  155. }
  156. };
  157. // --------------------------------------------------------------------------
  158. // some basic test code here
  159. int main()
  160. {
  161. // generate some random points for a kd-tree
  162. vector<point> vp;
  163. for (int i = 0; i < 100000; ++i) {
  164. vp.push_back(point(rand()%100000, rand()%100000));
  165. }
  166. kdtree tree(vp);
  167. // query some points
  168. for (int i = 0; i < 10; ++i) {
  169. point q(rand()%100000, rand()%100000);
  170. cout << "Closest squared distance to (" << q.x << ", " << q.y << ")"
  171. << " is " << tree.nearest(q) << endl;
  172. }
  173. return 0;
  174. }
  175. // --------------------------------------------------------------------------