  // KDTree.cc (5.3 KB)
  // A straightforward, but probably sub-optimal KD-tree implementation that's
  // probably good enough for most things (currently it's a 2D-tree)
  // - constructs from n points in O(n lg^2 n) time
  // - handles nearest-neighbor query in O(lg n) if points are well distributed
  // - worst case for nearest-neighbor may be linear in pathological cases
  // Sonny Chan, Stanford University, April 2009
  // Coordinate type for every point in the tree, and a sentinel equal to the
  // largest representable coordinate (used, e.g., to initialise an empty
  // bounding box so the first point always tightens it).
  typedef long long ntype;
  const ntype sentry(numeric_limits<ntype>::max());
  10. // point structure for 2D-tree, can be extended to 3D
  11. struct point {
  12. ntype x, y;
  13. point(ntype xx = 0, ntype yy = 0) : x(xx), y(yy) {}
  14. };
  15. bool operator==(const point &a, const point &b) {
  16. return a.x == b.x && a.y == b.y;
  17. }
  18. // sorts points on x-coordinate
  19. bool on_x(const point &a, const point &b) {
  20. return a.x < b.x;
  21. }
  22. // sorts points on y-coordinate
  23. bool on_y(const point &a, const point &b) {
  24. return a.y < b.y;
  25. }
  26. // squared distance between points
  27. ntype pdist2(const point &a, const point &b) {
  28. ntype dx = a.x-b.x, dy = a.y-b.y;
  29. return dx*dx + dy*dy;
  30. }
  31. // bounding box for a set of points
  32. struct bbox {
  33. ntype x0, x1, y0, y1;
  34. bbox() : x0(sentry), x1(-sentry), y0(sentry), y1(-sentry) {}
  35. // computes bounding box from a bunch of points
  36. void compute(const vector<point> &v) {
  37. for (int i = 0; i < v.size(); ++i) {
  38. x0 = min(x0, v[i].x); x1 = max(x1, v[i].x);
  39. y0 = min(y0, v[i].y); y1 = max(y1, v[i].y);
  40. }
  41. }
  42. // squared distance between a point and this bbox, 0 if inside
  43. ntype distance(const point &p) {
  44. if (p.x < x0) {
  45. if (p.y < y0) return pdist2(point(x0, y0), p);
  46. else if (p.y > y1) return pdist2(point(x0, y1), p);
  47. else return pdist2(point(x0, p.y), p);
  48. }
  49. else if (p.x > x1) {
  50. if (p.y < y0) return pdist2(point(x1, y0), p);
  51. else if (p.y > y1) return pdist2(point(x1, y1), p);
  52. else return pdist2(point(x1, p.y), p);
  53. }
  54. else {
  55. if (p.y < y0) return pdist2(point(p.x, y0), p);
  56. else if (p.y > y1) return pdist2(point(p.x, y1), p);
  57. else return 0;
  58. }
  59. }
  60. };
  61. // stores a single node of the kd-tree, either internal or leaf
  62. struct kdnode {
  63. bool leaf; // true if this is a leaf node (has one point)
  64. point pt; // the single point of this is a leaf
  65. bbox bound; // bounding box for set of points in children
  66. kdnode *first, *second; // two children of this kd-node
  67. kdnode() : leaf(false), first(0), second(0) {}
  68. ~kdnode() { if (first) delete first; if (second) delete second; }
  69. // intersect a point with this node (returns squared distance)
  70. ntype intersect(const point &p) {
  71. return bound.distance(p);
  72. }
  73. // recursively builds a kd-tree from a given cloud of points
  74. void construct(vector<point> &vp) {
  75. // compute bounding box for points at this node
  76. bound.compute(vp);
  77. // if we're down to one point, then we're a leaf node
  78. if (vp.size() == 1) {
  79. leaf = true;
  80. pt = vp[0];
  81. }
  82. else {
  83. // split on x if the bbox is wider than high (not best heuristic...)
  84. if (bound.x1-bound.x0 >= bound.y1-bound.y0)
  85. sort(vp.begin(), vp.end(), on_x);
  86. // otherwise split on y-coordinate
  87. else
  88. sort(vp.begin(), vp.end(), on_y);
  89. // divide by taking half the array for each child
  90. // (not best performance if many duplicates in the middle)
  91. int half = vp.size()/2;
  92. vector<point> vl(vp.begin(), vp.begin()+half);
  93. vector<point> vr(vp.begin()+half, vp.end());
  94. first = new kdnode(); first->construct(vl);
  95. second = new kdnode(); second->construct(vr);
  96. }
  97. }
  98. };
  99. // simple kd-tree class to hold the tree and handle queries
  100. struct kdtree {
  101. kdnode *root;
  102. // constructs a kd-tree from a points (copied here, as it sorts them)
  103. kdtree(const vector<point> &vp) {
  104. vector<point> v(vp.begin(), vp.end());
  105. root = new kdnode();
  106. root->construct(v);
  107. }
  108. ~kdtree() { delete root; }
  109. // recursive search method returns squared distance to nearest point
  110. ntype search(kdnode *node, const point &p) {
  111. if (node->leaf) {
  112. // commented special case tells a point not to find itself
  113. // if (p == node->pt) return sentry;
  114. // else
  115. return pdist2(p, node->pt);
  116. }
  117. ntype bfirst = node->first->intersect(p);
  118. ntype bsecond = node->second->intersect(p);
  119. // choose the side with the closest bounding box to search first
  120. // (note that the other side is also searched if needed)
  121. if (bfirst < bsecond) {
  122. ntype best = search(node->first, p);
  123. if (bsecond < best)
  124. best = min(best, search(node->second, p));
  125. return best;
  126. }
  127. else {
  128. ntype best = search(node->second, p);
  129. if (bfirst < best)
  130. best = min(best, search(node->first, p));
  131. return best;
  132. }
  133. }
  134. // squared distance to the nearest
  135. ntype nearest(const point &p) {
  136. return search(root, p);
  137. }
  138. };
  139. // some basic test code here
  140. int main() {
  141. // generate some random points for a kd-tree
  142. vector<point> vp;
  143. for (int i = 0; i < 100000; ++i) {
  144. vp.push_back(point(rand()%100000, rand()%100000));
  145. }
  146. kdtree tree(vp);
  147. // query some points
  148. for (int i = 0; i < 10; ++i) {
  149. point q(rand()%100000, rand()%100000);
  150. cout << "Closest squared distance to (" << q.x << ", " << q.y << ")"
  151. << " is " << tree.nearest(q) << endl;
  152. }
  153. return 0;
  154. }