Browse Source

add sum to segment tree

Olivier Marty 8 years ago
parent
commit
4af95fd1e0
3 changed files with 82 additions and 53 deletions
  1. 2 1
      code/Makefile
  2. 31 21
      code/SegmentTree.cpp
  3. 49 31
      code/SegmentTree_test.cpp

+ 2 - 1
code/Makefile

@@ -1,6 +1,7 @@
 TEST_SRC=$(wildcard *_test.cpp)
 TEST_EXE=$(TEST_SRC:.cpp=)
 TEST_RUN=$(TEST_SRC:.cpp=.run)
+CC=g++ -std=c++11
 
 run: $(TEST_RUN)
 
@@ -11,7 +12,7 @@ run: $(TEST_RUN)
 
 %_test: %_test.cpp %.cpp
 	@echo Compiling $@...
-	@g++ -o $@ $<
+	@$(CC) -o $@ $<
 
 clean:
 	rm -f $(TEST_EXE)

+ 31 - 21
code/SegmentTree.cpp

@@ -1,28 +1,37 @@
-// segment tree for minimum
+// segment tree for minimum and sum
 // root is in tree[1], children in tree[2i] and tree[2i+1]
 // all ranges in build/updates/queries are 0-based
 const int INFI=1000000000;
 struct Node {
-  int v; int up;
+  int min, sum, up, size; // size of subtree
   void update(int x) {
-    v += x;
+    min += x;
+    sum += x*size;
     up += x;
   }
   Node() {
-    v = INFI;
-    up = 0;
+    min = INFI;
+    up = sum = size = 0;
   }
 };
 Node *tree;
 int N;
+void compute(int v) { // update tree[v] from its children, v < N
+  tree[v].min = min(tree[2*v].min, tree[2*v+1].min);
+  tree[v].sum = tree[2*v].sum + tree[2*v+1].sum;
+}
 // read values[0..size-1]
 void build(int size, int values[]) {
-  N = 1 << ((int)log2(size-1)+1);
+  N = size > 1 ? 1 << ((int)log2(size-1)+1) : 1;
   tree = new Node[2*N];
-  for(int i = 0; i < size; i++) // leaves
-    tree[N+i].v = values[i];
-  for(int i = N-1; i > 0; i--) // interns
-    tree[i].v = min(tree[2*i].v, tree[2*i+1].v);
+  for(int i = 0; i < size; i++) { // leaves
+    tree[N+i].min = tree[N+i].sum = values[i];
+    tree[N+i].size = 1;
+  }
+  for(int i = N-1; i > 0; i--) { // interns
+    compute(i);
+    tree[i].size = tree[2*i].size + tree[2*i+1].size;
+  }
 }
 void push(int v) {
   if(2*v < 2*N) // left subtree
@@ -32,19 +41,20 @@ void push(int v) {
   tree[v].up = 0;
 }
 // v: current vertex with corressponding range [left, right)
-// find mimimum in the rangeg [l, r)
-int query_aux(int v, int left, int right, int l, int r) {
+// find mimimum and sum of the range [l, r)
+pair<int, int> query_aux(int v, int left, int right, int l, int r) {
   push(v);
   if(right <= l || r <= left) // outside
-    return INFI;
+    return {INFI, 0};
   if(l <= left && right <= r) // inside
-    return tree[v].v;
+    return {tree[v].min, tree[v].sum};
   int m = (left+right)/2;
-  int left_min = query_aux(2*v, left, m, l, r); // left subtree
-  int right_min = query_aux(2*v+1, m, right, l, r); // right
-  return min(left_min, right_min);
+  int lm, rm, ls, rs;
+  tie(lm, ls) = query_aux(2*v, left, m, l, r); // left subtree
+  tie(rm, rs) = query_aux(2*v+1, m, right, l, r); // right
+  return {min(lm, rm), ls+rs};
 }
-int query(int l, int r) {
+pair<int, int> query(int l, int r) {
   return query_aux(1, 0, N, l, r);
 }
 // update element at index i with value x
@@ -59,9 +69,9 @@ void update(int i, int x) {
       v++; // i on right
     po /= 2;
   }
-  tree[i].v = x; // update el
+  tree[i].min = tree[i].sum = x; // update el
   for(i /= 2; i > 0; i /= 2) // update all segments containing i
-    tree[i].v = min(tree[2*i].v, tree[2*i+1].v);
+    compute(i);
 }
 // v: the current vertex with corressponding range in [left, right)
 // add value x for element in range [l, r)
@@ -76,7 +86,7 @@ void update_range_aux(int v, int left, int right, int l, int r, int x) {
   int m = (left+right)/2;
   update_range_aux(2*v, left, m, l, r, x);
   update_range_aux(2*v+1, m, right, l, r, x);
-  tree[v].v = min(tree[2*v].v, tree[2*v+1].v);
+  compute(v);
 }
 // update [l,r)
 void update_range(int l, int r, int x) {

+ 49 - 31
code/SegmentTree_test.cpp

@@ -1,45 +1,63 @@
 #include <bits/stdc++.h>
 using namespace std;
-int min(int a, int b) {
-  return a<b ? a : b;
-}
+
 #include "SegmentTree.cpp"
+
 int main() {
   int values[] = {1,2,3,4,5,6,7};
   build(sizeof(values)/sizeof(int), values);
-  assert(query(1,5) == 2);
-  assert(query(0,5) == 1);
-  assert(query(2,8) == 3);
-  assert(query(4,5) == 5);
+  assert(query(1,5) == make_pair(2, 14));
+  assert(query(0,5) == make_pair(1, 15));
+  assert(query(2,8) == make_pair(3, 25));
+  assert(query(4,5) == make_pair(5, 5));
   update_range(0, 7, 1);
-  assert(query(1,5) == 3);
-  assert(query(0,7) == 2);
-  assert(query(2,8) == 4);
-  assert(query(4,5) == 6);
+  assert(query(1,5) == make_pair(3, 18));
+  assert(query(0,7) == make_pair(2, 35));
+  assert(query(2,8) == make_pair(4, 30));
+  assert(query(4,5) == make_pair(6, 6));
   update_range(2, 6, 2);
   update_range(3, 7, 2);
-  assert(query(1,5) == 3);
-  assert(query(0,7) == 2);
-  assert(query(2,8) == 6);
-  assert(query(4,5) == 10);
-  assert(query(5,7) == 10);
+  assert(query(1,5) == make_pair(3, 28));
+  assert(query(0,7) == make_pair(2, 51));
+  assert(query(2,8) == make_pair(6, 46));
+  assert(query(4,5) == make_pair(10, 10));
+  assert(query(5,7) == make_pair(10, 21));
   update(2, 0);
-  assert(query(1,5) == 0);
-  assert(query(0,7) == 0);
-  assert(query(2,8) == 0);
-  assert(query(4,5) == 10);
-  assert(query(5,7) == 10);
+  assert(query(1,5) == make_pair(0, 22));
+  assert(query(0,7) == make_pair(0, 45));
+  assert(query(2,8) == make_pair(0, 40));
+  assert(query(4,5) == make_pair(10, 10));
+  assert(query(5,7) == make_pair(10, 21));
   update_range(1, 4, 2);
-  assert(query(1,5) == 2);
-  assert(query(0,7) == 2);
-  assert(query(4,8) == 10);
-  assert(query(4,5) == 10);
-  assert(query(6,7) == 10);
-  assert(query(1,3) == 2);
+  assert(query(1,5) == make_pair(2, 28));
+  assert(query(0,7) == make_pair(2, 51));
+  assert(query(4,8) == make_pair(10, 31));
+  assert(query(4,5) == make_pair(10, 10));
+  assert(query(6,7) == make_pair(10, 10));
+  assert(query(1,3) == make_pair(2, 7));
   update(5, -1);
-  assert(query(0, 5) == 2);
-  assert(query(6, 7) == 10);
-  assert(query(5, 6) == -1);
-  assert(query(4, 7) == -1);
+  assert(query(0, 5) == make_pair(2, 30));
+  assert(query(6, 7) == make_pair(10, 10));
+  assert(query(5, 6) == make_pair(-1, -1));
+  assert(query(4, 7) == make_pair(-1, 19));
+  update(5, 3);
+  update_range(4, 7, 1);
+  assert(query(4, 7) == make_pair(4, 26));
+  assert(query(0, 7) == make_pair(2, 46));
+  delete[] tree;
+
+  int values2[] = {4};
+  build(sizeof(values2)/sizeof(int), values2);
+  assert(query(0,1) == make_pair(4, 4));
+  update(0, 2);
+  assert(query(0,1) == make_pair(2, 2));
+  update_range(0, 1, 7);
+  assert(query(0,1) == make_pair(9, 9));
+  update(0, 3);
+  assert(query(0,1) == make_pair(3, 3));
+  update_range(0, 1, 7);
+  update(0, 3);
+  assert(query(0,1) == make_pair(3, 3));
+  delete[] tree;
   return 0;
 }