Segment Trees (Recursive)

Written by Timon Gehr.

Divide and Conquer on Ranges

Consider the following generic divide-and-conquer algorithm, where solve(l, r) operates on a segment \(a_l,\ldots,a_{r-1}\) of an array \(a\):

T solve(int l, int r) {
    if (l + 1 == r) return a[l];
    int m = l + (r-l)/2;
    return conquer(solve(l, m), solve(m, r));
}

int n = a.size();
auto result = solve(0, n);

For example, if \(a\) is an array of integers, we can compute its maximum as follows:

int largest(int l, int r) {
    if (l + 1 == r) return a[l];
    int m = l + (r-l)/2;
    return max(largest(l, m), largest(m, r));
}

int n = a.size();
int result = largest(0, n);
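The generic scheme can also be written with conquer passed in as a parameter, so the same recursion serves several problems. The following is a minimal sketch under assumptions of its own: the array is passed explicitly instead of being a global, and the helper name total is not part of the text above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Generic divide and conquer over a[l], ..., a[r-1]; requires l < r.
// Conquer is any callable combining the results of two adjacent segments.
template <class T, class Conquer>
T solve(const std::vector<T>& a, int l, int r, Conquer conquer) {
    assert(0 <= l && l < r && r <= (int)a.size());
    if (l + 1 == r) return a[l];  // single element: base case
    int m = l + (r - l) / 2;      // split point with l < m < r
    return conquer(solve(a, l, m, conquer), solve(a, m, r, conquer));
}

// Two instantiations on the whole (non-empty) array: maximum and sum.
int largest(const std::vector<int>& a) {
    return solve(a, 0, (int)a.size(), [](int x, int y) { return std::max(x, y); });
}
int total(const std::vector<int>& a) {
    return solve(a, 0, (int)a.size(), [](int x, int y) { return x + y; });
}
```

For the array {3, 1, 4, 1, 5}, largest evaluates to 5 and total to 14.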

If we let \(a = [5, 4, 2, 6, 9, 1, 4, 3]\), this results in the following recursion tree. Each node shows the result of the corresponding recursive call of largest, and each parent calls its children in order from left to right.

results:

                  9
               /     \

         6               9
       /   \           /    \
    5        6       9       4
   /  \     /  \    /  \    /  \
  5    4   2   6   9    1  4   3

We can also draw a similar recursion tree labeling nodes with arguments \([l, r)\) instead of return values:

arguments [l, r):

                 [0, 8)
                /     \

       [0, 4)              [4, 8)
       /    \              /    \
  [0, 2)    [2, 4)    [4, 6)    [6, 8)
   /  \      /  \      /  \      /  \
[0,1)[1,2)[2,3)[3,4)[4,5)[5,6)[6,7)[7,8)

Note that this tree depends only on the length of \(a\), but not its contents.

Now let’s change \(a_{2}\) from \(2\) to \(10\) and draw the new recursion tree of results when we call largest:

results:

                10*
              /     \

       10*              9
      /   \           /    \
   5        10*     9       4
  /  \     /  \    /  \    /  \
 5    4  10*  6   9    1  4   3

The nodes that changed are annotated with an asterisk (*). Note that only nodes on the path from the modified leaf to the root differ from before. We can therefore optimize our second computation by storing all intermediate values in a cache and recomputing only the ones that actually need to change. We can then chain many such updates of the array \(a\) and obtain the current maximum in \(O(\log n)\) time per update.

Building a Segment Tree

Such a recursion tree, storing intermediate return values in its nodes, is called a segment tree. We can build it by augmenting our solve function with a cache \(b\):

vector<T> b;
T build(int l, int r, int i) {
    if (l + 1 == r) return b[i] = a[l];
    int m = l + (r-l)/2;
    return b[i] = conquer(build(l, m, 2*i+1), build(m, r, 2*i+2));
}
b.resize(4 * n);
build(0, n, 0);

The additional parameter \(i\) serves as an index into the cache. We make sure that every node gets a different index. (In principle, we could also use a std::map<pair<int,int>, T> mapping ranges \([l, r)\) to the return value, but indexing a vector is more efficient.)
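To illustrate the map-based alternative mentioned in parentheses, here is a hedged sketch for the maximum example. It uses std::map, an ordered map, because std::pair has no standard hash. The names cache and build_map are chosen for this sketch only.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Cache keyed by the range [l, r) itself instead of a node index.
std::map<std::pair<int, int>, int> cache;

// Same recursion as build, but every result is stored under its range.
int build_map(const std::vector<int>& a, int l, int r) {
    assert(0 <= l && l < r && r <= (int)a.size());
    if (l + 1 == r) return cache[{l, r}] = a[l];
    int m = l + (r - l) / 2;
    int left = build_map(a, l, m);
    int right = build_map(a, m, r);
    return cache[{l, r}] = std::max(left, right);
}
```

Every lookup in this variant costs a logarithmic number of pair comparisons, whereas the index \(i\) gives direct access into a vector.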

In our running example, the argument \(i\) takes on the following values:

argument i:

                  0
               /     \

         1               2
       /   \           /    \
    3        4       5       6
   /  \     /  \    /  \    /  \
  7    8   9   10  11  12  13  14

This works because the binary representation of \(i+1\) encodes the path from the root to each node: after the leading \(1\), each \(0\) means “go left” and each \(1\) means “go right”. (We have \((2\cdot i+1)+1 = 2\cdot (i+1)\) and \((2\cdot i+2)+1 = 2\cdot (i+1)+1\), so moving to a child appends one bit.)

binary representation of i+1:

                   1
                /     \

         10                 11
       /    \             /    \
   100       101       110      111
   /  \      /  \      /  \     /  \
1000 1001 1010 1011 1100 1101 1110 1111
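The encoding can be checked with a small helper that recovers the path from an index. This is only a sketch. The function name path_to_node and the 'L'/'R' notation are assumptions of this example.

```cpp
#include <cassert>
#include <string>

// Returns the path from the root to node i as a string of 'L'/'R' moves,
// assuming the numbering used by build: root 0, children 2*i+1 and 2*i+2.
std::string path_to_node(int i) {
    assert(i >= 0);
    std::string path;
    // Walk from the node up to the root (where v == 1). The lowest bit of v
    // says whether the current node is a left (0) or a right (1) child.
    for (int v = i + 1; v > 1; v /= 2)
        path.insert(path.begin(), v % 2 == 0 ? 'L' : 'R');
    return path;
}
```

For example, path_to_node(9) yields "LRL", matching the bits 010 that follow the leading 1 in 1010.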

We have chosen the length of the cache as \(4\) times the length \(n\) of the underlying array. (The required length of the cache never decreases when \(n\) increases. If \(n\) is a power of two, a cache length of \(2\cdot n\) would be sufficient and there is always at least one power of two between \(n\) and \(2\cdot n\), therefore \(4\cdot n\) is large enough.)

This also implies that the function call build(0, n, 0) calls conquer \(O(n)\) times.
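Both claims can be checked experimentally without building anything. The sketch below mirrors the recursion of build. The helper names max_index and conquer_calls are hypothetical and exist only for this check.

```cpp
#include <algorithm>
#include <cassert>

// Largest cache index that build(l, r, i) would write to; requires l < r.
int max_index(int l, int r, int i) {
    assert(l < r);
    if (l + 1 == r) return i;
    int m = l + (r - l) / 2;
    return std::max(max_index(l, m, 2*i+1), max_index(m, r, 2*i+2));
}

// Number of conquer calls made by build on [l, r): one per inner node.
int conquer_calls(int l, int r) {
    assert(l < r);
    if (l + 1 == r) return 0;
    int m = l + (r - l) / 2;
    return 1 + conquer_calls(l, m) + conquer_calls(m, r);
}
```

For \(n = 5\), the largest index used is \(14\), which is more than \(2\cdot n\) but less than \(4\cdot n\), and conquer is called \(n - 1 = 4\) times.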

Updating Elements

Whenever we update the entry \(a_k\) of the array \(a\), we have to recompute the entries in the cache \(b\) that depend on this value. This is precisely the set of nodes corresponding to calls with segments \([l, r)\) containing \(k\). The following function realizes such an update:

// set a[k] to x and update segment tree
T update(int l, int r, int i, int k, int x) {
    if (!(l<=k && k<r)) return b[i]; // entry does not depend on a[k], use cached value
    if (l + 1 == r) return b[i] = a[k] = x; // leaf update
    int m = l + (r-l)/2;
    return b[i] = conquer(update(l, m, 2*i+1, k, x), update(m, r, 2*i+2, k, x)); // recompute from both halves and update cache
}

int result = update(0, n, 0, k, x); // example of usage

In essence, this is just the original divide-and-conquer algorithm solve, except that it reads values that do not depend on \(a_k\) from the cache. It recomputes all other values. The function updates one node on each level of the segment tree and calls conquer \(O(\log n)\) times.
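Putting build and update together for the maximum example might look as follows. This is a minimal sketch, not a definitive implementation: the struct MaxTree and the wrappers set and maximum are names introduced here, and T is fixed to int.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Segment tree for the maximum of an array with point updates.
struct MaxTree {
    int n;
    std::vector<int> a, b;  // a: underlying array, b: cache indexed by node

    explicit MaxTree(const std::vector<int>& values)
        : n((int)values.size()), a(values), b(4 * values.size()) {
        assert(n > 0);
        build(0, n, 0);
    }
    int build(int l, int r, int i) {
        if (l + 1 == r) return b[i] = a[l];
        int m = l + (r - l) / 2;
        return b[i] = std::max(build(l, m, 2*i+1), build(m, r, 2*i+2));
    }
    int update(int l, int r, int i, int k, int x) {
        if (!(l <= k && k < r)) return b[i];     // segment does not contain k
        if (l + 1 == r) return b[i] = a[k] = x;  // leaf holding a[k]
        int m = l + (r - l) / 2;
        return b[i] = std::max(update(l, m, 2*i+1, k, x),
                               update(m, r, 2*i+2, k, x));
    }
    // Convenience wrappers (hypothetical names for this sketch).
    int set(int k, int x) { assert(0 <= k && k < n); return update(0, n, 0, k, x); }
    int maximum() const { return b[0]; }
};
```

Starting from {3, 1, 4, 1, 5}, the maximum is 5. After set(4, 0) it is 4, and after set(1, 7) it is 7.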

Range Queries

So far, we have shown how to perform updates and recompute the result of the divide-and-conquer algorithm on the entire input array. However, if conquer is an associative operation (i.e., conquer(conquer(a, b), c) == conquer(a, conquer(b, c))), segment trees also allow results to be computed for sub-ranges. (For example, we can answer queries of the form: “what is the maximum value among \(a_{\mathit{ql}},\ldots,a_{\mathit{qr}-1}\)?”)

In the following, we assume there is a value neutral_element such that conquer(a, neutral_element) == a and conquer(neutral_element, b) == b. It is always possible to modify conquer such that it has such a neutral element. Many operations naturally have neutral elements, for example \(-\infty\) is a neutral element for \(\max\) and \(0\) is a neutral element for \(+\). The interpretation of the neutral element is that it is the result of solve if the input is empty:
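One way to give an arbitrary conquer a neutral element is to add an artificial "empty" value. The sketch below assumes C++17 and uses std::optional for this purpose. The function name conquer_with_neutral is made up for the illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Extends conquer with an artificial neutral element: an empty optional
// stands for the result on an empty segment.
template <class T, class Conquer>
std::optional<T> conquer_with_neutral(const std::optional<T>& x,
                                      const std::optional<T>& y,
                                      Conquer conquer) {
    if (!x) return y;  // conquer(neutral_element, y) == y
    if (!y) return x;  // conquer(x, neutral_element) == x
    return conquer(*x, *y);
}
```

If conquer is associative, the extended operation is associative as well, so the extension does not interfere with range queries.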

T solve(int l, int r) {
    if (l >= r) return neutral_element;
    if (l + 1 == r) return a[l];
    int m = l + (r-l)/2;
    return conquer(solve(l, m), solve(m, r));
}

The implementation is again a modified version of solve, which greedily selects and combines a set of segment tree nodes that together cover the range \([\mathit{ql},\mathit{qr})\). (This set of segments is also called the canonical cover of \([\mathit{ql},\mathit{qr})\).)

// compute solve(ql, qr) on a segment tree representing the recursion tree of solve(l, r)
T query_range(int l, int r, int i, int ql, int qr) {
    if (r <= ql || qr <= l) return neutral_element; // range does not intersect query range
    if (ql <= l && r <= qr) return b[i]; // range contained in query range, use cached value
    int m = l + (r-l)/2;
    return conquer(query_range(l, m, 2*i+1, ql, qr), query_range(m, r, 2*i+2, ql, qr));
}

int result = query_range(0, n, 0, ql, qr); // example of usage

Ignoring calls that return at the first if statement, query_range accesses at most four segment tree nodes per level. This is because on any level, the segment tree nodes whose ranges intersect the query range are contiguous. Except for possibly the first and last such node, all of them are fully covered by \([\mathit{ql},\mathit{qr})\). Only a partially covered node recurses into its children, so at most two nodes per level recurse, and at most four nodes are accessed on the next level. All other fully covered nodes are never reached, because query_range uses the cached value of a fully covered ancestor instead of recursing deeper down in the tree.

Since the segment tree has \(O(\log n)\) levels, query_range calls conquer \(O(\log n)\) times.
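The number of accessed nodes can be observed directly by instrumenting query_range. The following sketch does this for range sums, with conquer being + and neutral_element being 0. The struct CountingSumTree, the extra depth parameter, and the accessed counter are additions of this example and not part of the functions above.

```cpp
#include <cassert>
#include <vector>

// Range-sum segment tree that records, per tree level, how many nodes
// query_range accesses past its first if statement.
struct CountingSumTree {
    int n;
    std::vector<long long> a, b;
    std::vector<int> accessed;  // accessed[d]: visited nodes on level d

    explicit CountingSumTree(const std::vector<long long>& values)
        : n((int)values.size()), a(values), b(4 * values.size()) {
        assert(n > 0);
        build(0, n, 0);
    }
    long long build(int l, int r, int i) {
        if (l + 1 == r) return b[i] = a[l];
        int m = l + (r - l) / 2;
        return b[i] = build(l, m, 2*i+1) + build(m, r, 2*i+2);
    }
    long long query_range(int l, int r, int i, int ql, int qr, int depth) {
        if (r <= ql || qr <= l) return 0;     // disjoint: neutral element
        if ((int)accessed.size() <= depth) accessed.resize(depth + 1, 0);
        ++accessed[depth];
        if (ql <= l && r <= qr) return b[i];  // fully covered: cached value
        int m = l + (r - l) / 2;
        return query_range(l, m, 2*i+1, ql, qr, depth + 1)
             + query_range(m, r, 2*i+2, ql, qr, depth + 1);
    }
    // Sum of a[ql], ..., a[qr-1]; resets the counters first.
    long long sum(int ql, int qr) {
        accessed.clear();
        return query_range(0, n, 0, ql, qr, 0);
    }
};
```

On the array {1, 2, 3, 4, 5, 6, 7, 8}, sum(1, 7) returns 27 and leaves accessed equal to {1, 2, 4, 2}: the third level is visited at four nodes, and no level at more.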