DP optimization

Written by Daniel Rutschmann. Translated by Johannes Kapfhammer.

Refer to the German page for a full article.

Convex Hull Trick

The convex hull trick applies to any DP that looks like this:

\begin{equation*} \mathtt{DP}[i]=a[i] + \max_{k\le i}\{m[k]\cdot x[i]+q[k]\} \end{equation*}
\(m[k]\) and \(q[k]\) are constants only depending on \(k\) (not \(i\)) and have to be known as soon as \(i=k\).
\(a[i]\) and \(x[i]\) are constants only depending on \(i\) (not \(k\)) and have to be known right before \(\mathtt{DP}[i]\) is computed.
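
For reference, the recurrence can be evaluated directly in \(\mathcal O(n^2)\). The following is a minimal sketch and not part of the original article: the name naive_dp is made up for illustration, and it assumes that all four arrays are known up front (in real tasks \(q[k]\) often depends on earlier DP values). It is mainly useful as a baseline to test a faster solution against.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Naive O(n^2) evaluation of DP[i] = a[i] + max_{k<=i} (m[k]*x[i] + q[k]).
// Assumption: all four arrays have the same length and are fully known.
vector<long long> naive_dp(vector<long long> const &a, vector<long long> const &x,
                           vector<long long> const &m, vector<long long> const &q) {
  int n = a.size();
  vector<long long> dp(n);
  for (int i = 0; i < n; ++i) {
    long long best = m[0] * x[i] + q[0]; // k = 0 is always allowed
    for (int k = 1; k <= i; ++k)
      best = max(best, m[k] * x[i] + q[k]);
    dp[i] = a[i] + best;
  }
  return dp;
}
```

The convex hull trick replaces the inner loop over \(k\) by a single query on a data structure.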

Example: Commando

Before we show how the trick works, we solve the task Commando.

The DP looks like this:

\begin{equation*} \mathtt{DP}[i] = \max_{k\le i} \{a\cdot(\sum_{j=k}^i x[j])^2 + b\cdot(\sum_{j=k}^i x[j]) + c + \mathtt{DP}[k-1]\} \end{equation*}

This looks quadratic and not at all like the form above, but we can simplify it by computing the prefix sums \(s[i+1]=x[i]+s[i]\) (with \(s[0]=0\)), multiplying out and collecting the terms:

\begin{align*} \mathtt{DP}[i] &= \max_{k\le i} \{a\cdot (s[i+1] - s[k])^2 + b\cdot(s[i+1] - s[k]) + c + \mathtt{DP}[k-1]\}\\ &= \max_{k\le i} \{a\cdot s[i+1]^2 + a\cdot s[k]^2 - 2\cdot a\cdot s[i+1]\cdot s[k] + b\cdot s[i+1] - b\cdot s[k] + c + \mathtt{DP}[k-1]\}\\ &= \underbrace{a\cdot s[i+1]^2 + b\cdot s[i+1] + c}_{a[i]} + \max_{k\le i} \{\underbrace{-2\cdot a\cdot s[k]}_{m[k]}\cdot \underbrace{s[i+1]}_{x[i]} + \underbrace{a\cdot s[k]^2 - b\cdot s[k] + \mathtt{DP}[k-1]}_{q[k]}\} \end{align*}
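
To sanity-check the algebra, one can evaluate both the original recurrence and the rewritten form and compare the results. The sketch below does this in \(\mathcal O(n^2)\). It assumes a 0-indexed input array x and the base case \(\mathtt{DP}[-1]=0\); the function names commando_naive and commando_lines are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// prefix sums with s[0] = 0 and s[i+1] = s[i] + x[i]
vector<long long> prefix_sums(vector<long long> const &x) {
  vector<long long> s(x.size() + 1, 0);
  for (size_t i = 0; i < x.size(); ++i) s[i + 1] = s[i] + x[i];
  return s;
}

// DP straight from the definition: the last group is x[k..i]
vector<long long> commando_naive(long long a, long long b, long long c,
                                 vector<long long> const &x) {
  int n = x.size();
  vector<long long> s = prefix_sums(x), dp(n);
  for (int i = 0; i < n; ++i) {
    dp[i] = LLONG_MIN;
    for (int k = 0; k <= i; ++k) {
      long long sum = s[i + 1] - s[k];
      long long before = k ? dp[k - 1] : 0; // DP[-1] = 0
      dp[i] = max(dp[i], a * sum * sum + b * sum + c + before);
    }
  }
  return dp;
}

// same DP in the form a[i] + max_k (m[k]*x[i] + q[k])
vector<long long> commando_lines(long long a, long long b, long long c,
                                 vector<long long> const &x) {
  int n = x.size();
  vector<long long> s = prefix_sums(x), dp(n);
  for (int i = 0; i < n; ++i) {
    long long xi = s[i + 1];
    long long ai = a * xi * xi + b * xi + c;
    long long best = LLONG_MIN;
    for (int k = 0; k <= i; ++k) {
      long long mk = -2 * a * s[k];
      long long qk = a * s[k] * s[k] - b * s[k] + (k ? dp[k - 1] : 0);
      best = max(best, mk * xi + qk);
    }
    dp[i] = ai + best;
  }
  return dp;
}
```

For \(a=-1\), \(b=10\), \(c=-20\) and the values 2, 2, 3, 4, both functions produce the DP values -4, 4, 5, 9.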

Convex Hull Data Structure

To speed up the DP equation

\begin{equation*} \mathtt{DP}[i]=a[i] + \max_{k\le i}\{m[k]\cdot x[i]+q[k]\} \end{equation*}

we implement a hull data structure that supports the following queries:

  • \(\mathtt{insert}(m, q)\): inserting a new line \(y=m\cdot x+q\) in \(\mathcal O(\log n)\)
  • \(\mathtt{query}(x)\): computing \(\max_\ell\{m_\ell\cdot x+q_\ell\}\) in \(\mathcal O(\log n)\)

This allows us to solve the recurrence by repeatedly adding new lines and querying for the maximum:

hull h;
vector<long long> dp(n);
for (int i = 0; i < n; ++i) {
  h.insert({m[i], q[i]});
  dp[i] = a[i] + h.query(x[i]);
}

This data structure works by dynamically maintaining the upper envelope (the "convex hull") of the inserted linear functions. A query at a given \(x\) binary searches for the line that dominates at that \(x\).
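
As a small worked example (the three lines are chosen purely for illustration), take

\begin{equation*} \ell_1(x) = -x, \qquad \ell_2(x) = -1, \qquad \ell_3(x) = x. \end{equation*}

Then \(\max\{\ell_1(x), \ell_3(x)\} = |x| \ge 0 > \ell_2(x)\) for every \(x\), so \(\ell_2\) is never the maximum and does not need to be stored. The hull consists of \(\ell_1\) for \(x \le 0\) and \(\ell_3\) for \(x \ge 0\), and a query only has to decide on which side of the intersection at \(x=0\) it falls.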

Deque Implementation

Quite often, the DP equation of the convex hull trick

\begin{equation*} \mathtt{DP}[i]=a[i] + \max_{k\le i}\{m[k]\cdot x[i]+q[k]\} \end{equation*}

additionally satisfies

  • \(m[j+1]>m[j]\) (slopes are increasing) and
  • \(x[j+1]\ge x[j]\) (queries are increasing).

This is the case in our example problem above: The queries at \(x[i]=s[i+1]\) are increasing because \(s\) is the prefix sum of positive values. Because \(a<0\), the slope \(m[i]=-2\cdot a\cdot s[i]\) is also increasing.

The increasing slopes simplify the insertion: The new line has the largest slope so far, so it must be added at the right end of the hull. We remove lines from the right end as long as they are hidden (similar to the convex hull idea), and then add the new line. Every line is removed at most once, so an update is amortized \(\mathcal O(1)\).

The same trick can be applied to the query: Because the queries are increasing, we know that wherever the current maximal line is, we can never “go back” afterwards, so we can just remove the lines at the beginning that are no longer maximal. Such a query is also amortized \(\mathcal O(1)\).

Below is code for that case using a std::deque.

#include <cassert>
#include <deque>
using namespace std;

struct line {
  long long m, q; // y = m*x + q
  long long eval(long long x) const { return m*x + q; }
};

// check if l2 is completely below max(l1, l3)
// requires that l1.m < l2.m < l3.m
bool bad(line const &l1, line const &l2, line const &l3) {
  // __int128 is a GCC/Clang extension; use long double if it is not available
  return ((__int128)l2.q - l3.q) * (l2.m - l1.m) <=
         ((__int128)l1.q - l2.q) * (l3.m - l2.m);
}

struct hull {
  deque<line> slopes;

  // insert line to hull
  void insert(line l) {
    assert(slopes.empty() || l.m > slopes.back().m);
    while (slopes.size() > 1 &&
           bad(slopes.rbegin()[1], slopes.rbegin()[0], l))
      slopes.pop_back();
    slopes.push_back(l);
  }

  // maximum at x:  max_l l.m*x + l.q
  long long query(long long x) {
    assert(!slopes.empty());
    while (slopes.size() > 1 && slopes[0].eval(x) < slopes[1].eval(x))
      slopes.pop_front();
    return slopes.front().eval(x);
  }
};
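
Putting the pieces together, a complete solution for Commando could look roughly as follows. This is a sketch under assumptions rather than a reference solution: it repeats a condensed copy of the hull so that it compiles on its own, the function name commando is made up, input parsing is left out, and it assumes \(a<0\), positive values and that m*x + q fits into a long long.

```cpp
#include <cassert>
#include <deque>
#include <vector>
using namespace std;

struct line {
  long long m, q; // y = m*x + q
  long long eval(long long x) const { return m * x + q; }
};

// l2 is hidden by max(l1, l3); requires l1.m < l2.m < l3.m
bool bad(line const &l1, line const &l2, line const &l3) {
  return ((__int128)l2.q - l3.q) * (l2.m - l1.m) <=
         ((__int128)l1.q - l2.q) * (l3.m - l2.m);
}

struct hull {
  deque<line> slopes;
  void insert(line l) { // slopes must be strictly increasing
    while (slopes.size() > 1 && bad(slopes.rbegin()[1], slopes.rbegin()[0], l))
      slopes.pop_back();
    slopes.push_back(l);
  }
  long long query(long long x) { // queries must be non-decreasing
    while (slopes.size() > 1 && slopes[0].eval(x) < slopes[1].eval(x))
      slopes.pop_front();
    return slopes.front().eval(x);
  }
};

// best total value when splitting x into consecutive groups,
// where a group with sum X is worth a*X^2 + b*X + c
long long commando(long long a, long long b, long long c,
                   vector<long long> const &x) {
  int n = x.size();
  vector<long long> s(n + 1, 0); // prefix sums, s[0] = 0
  for (int i = 0; i < n; ++i) s[i + 1] = s[i] + x[i];

  hull h;
  long long prev_dp = 0; // DP[i-1], with DP[-1] = 0
  for (int i = 0; i < n; ++i) {
    // line for k = i: m[k] = -2*a*s[k], q[k] = a*s[k]^2 - b*s[k] + DP[k-1]
    h.insert({-2 * a * s[i], a * s[i] * s[i] - b * s[i] + prev_dp});
    long long xi = s[i + 1];
    prev_dp = a * xi * xi + b * xi + c + h.query(xi);
  }
  return prev_dp;
}
```

With \(a=-1\), \(b=10\), \(c=-20\) and the values 2, 2, 3, 4, the call commando(-1, 10, -20, {2, 2, 3, 4}) returns 9 (groups {2, 2}, {3}, {4}).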

To support arbitrary (not necessarily increasing) queries, the query function must not remove any lines. Instead, it can perform a binary search for the dominating line with upper_bound, similar to the std::multiset implementation below.
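
One possible shape of such a query is sketched below. Note that it uses a hand-written binary search over the indices of the deque instead of upper_bound, and compares neighbouring lines directly rather than storing intersection points. The name query_any is made up, and the deque is assumed to contain a valid hull (increasing slopes, no hidden lines).

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
using namespace std;

struct line {
  long long m, q; // y = m*x + q
  long long eval(long long x) const { return m * x + q; }
};

// Maximum at x over a hull stored with increasing slopes and no hidden lines.
// The hull is not modified, so queries may come in any order.
long long query_any(deque<line> const &slopes, long long x) {
  assert(!slopes.empty());
  // Line i+1 beats line i exactly to the right of their intersection, and
  // the intersections are sorted, so "line i+1 is better at x" is true for
  // a prefix of the indices. Find the first index where it is false.
  size_t lo = 0, hi = slopes.size() - 1;
  while (lo < hi) {
    size_t mid = lo + (hi - lo) / 2;
    if (slopes[mid].eval(x) < slopes[mid + 1].eval(x))
      lo = mid + 1;
    else
      hi = mid;
  }
  return slopes[lo].eval(x);
}
```

Each query then costs \(\mathcal O(\log n)\) instead of amortized \(\mathcal O(1)\), while insert stays unchanged.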

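If equal slopes \(m[j+1]=m[j]\) can occur, insert needs an additional check at the beginning: of two parallel lines, only the one with the larger \(q\) can ever be maximal, so the other one must be dropped.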
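
A hedged sketch of how that check might look in the deque version follows. Everything except the first if block in insert mirrors the code above; treating equal slopes this way is one possible choice, not necessarily the one the author had in mind.

```cpp
#include <cassert>
#include <deque>
using namespace std;

struct line {
  long long m, q; // y = m*x + q
  long long eval(long long x) const { return m * x + q; }
};

// l2 is hidden by max(l1, l3); requires l1.m < l2.m < l3.m
bool bad(line const &l1, line const &l2, line const &l3) {
  return ((__int128)l2.q - l3.q) * (l2.m - l1.m) <=
         ((__int128)l1.q - l2.q) * (l3.m - l2.m);
}

struct hull {
  deque<line> slopes;

  // slopes may now repeat: l.m >= slopes.back().m
  void insert(line l) {
    assert(slopes.empty() || l.m >= slopes.back().m);
    if (!slopes.empty() && slopes.back().m == l.m) {
      if (slopes.back().q >= l.q) return; // new line is never better
      slopes.pop_back();                  // old line is never better
    }
    while (slopes.size() > 1 &&
           bad(slopes.rbegin()[1], slopes.rbegin()[0], l))
      slopes.pop_back();
    slopes.push_back(l);
  }

  // maximum at x, for non-decreasing queries
  long long query(long long x) {
    assert(!slopes.empty());
    while (slopes.size() > 1 && slopes[0].eval(x) < slopes[1].eval(x))
      slopes.pop_front();
    return slopes.front().eval(x);
  }
};
```

After the check, the remaining lines all have strictly smaller slopes than the new one, so the precondition of bad holds again.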


Arbitrary Queries and Inserts

To support inserts in the middle and arbitrary queries, we need something more dynamic. We can avoid implementing our own balanced binary search tree by exploiting that a std::multiset with the transparent comparator less<> allows upper_bound to be called with a different key type, and therefore a different comparison operator, than the one used by insert.

As we have to binary search over the lines, we additionally need to remember xleft, the left endpoint of each segment in the hull. As we need to change that during an insert, we make it mutable.

#include <cassert>
#include <functional>
#include <iterator>
#include <limits>
#include <set>
#include <utility>
using namespace std;

struct line {
  long long m, q; // y = m*x + q
  // x from which this line becomes relevant in the hull
  // mutable means we can change xleft on a const-reference to line
  mutable long double xleft = numeric_limits<long double>::quiet_NaN();
};
bool operator<(line const& a, line const& b) { // sort lines by slope m, then by q
  return make_pair(a.m, a.q) < make_pair(b.m, b.q);
}
bool operator<(long long x, line const& l) { // binary search for x
  return x < l.xleft;
}

// x coordinate of the intersection between l1 and l2
// requires that l1.m != l2.m
long double intersect(line const &l1, line const &l2) {
  // convert before subtracting so the differences cannot overflow long long
  return ((long double)l2.q - l1.q) / ((long double)l1.m - l2.m);
}

// check if l2 is completely below max(l1, l3)
// requires that l1.m < l2.m < l3.m
bool bad(line const &l1, line const &l2, line const &l3) {
  // __int128 is a GCC/Clang extension; use long double if it is not available
  // cast before subtracting so the differences cannot overflow long long
  return ((__int128)l2.q - l3.q) * (l2.m - l1.m) <=
         ((__int128)l1.q - l2.q) * (l3.m - l2.m);
}

struct hull {
  multiset<line, less<>> slopes; // less<> to support upper_bound on long long's

  // insert line to hull
  void insert(line const &l) {
    // insert l and then fix the hull until it is convex again
    auto e = slopes.insert(l);

    // delete parallel lines (same slope): only the one with the larger q matters
    if (e != slopes.begin() && prev(e)->m == e->m) {
      slopes.erase(prev(e));
    } else if (next(e) != slopes.end() && e->m == next(e)->m) {
      slopes.erase(e);
      return;
    }

    // delete l again if it is hidden by the lines to the left and the right
    if (e != slopes.begin() && next(e) != slopes.end() &&
        bad(*prev(e), *e, *next(e))) {
      slopes.erase(e);
      return;
    }

    // delete lines to the right of l and adjust their xleft
    if (next(e) != slopes.end()) {
      while (next(e, 2) != slopes.end() && bad(*e, *next(e), *next(e, 2)))
        slopes.erase(next(e));
      next(e)->xleft = intersect(*e, *next(e));
    }

    // delete lines to the left of l and adjust xleft of l
    if (e != slopes.begin()) {
      while (prev(e) != slopes.begin() && bad(*prev(e, 2), *prev(e), *e))
        slopes.erase(prev(e));
      e->xleft = intersect(*e, *prev(e));
    } else {
      e->xleft = -numeric_limits<long double>::infinity();
    }
  }

  // maximum at x:  max_l l.m*x + l.q
  long long query(long long x) {
    assert(!slopes.empty());
    line const& l = *prev(slopes.upper_bound(x)); // upper_bound can never return begin() because it is -inf
    return l.m * x + l.q;
  }
};
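
The implementation above relies on two language features: a mutable member that can be changed through the const access a multiset gives to its elements, and heterogeneous lookup with less<> (available since C++14). The following sketch isolates both on a made-up type seg, so they can be tried out separately. The names seg and find_segment are hypothetical and not part of the hull.

```cpp
#include <cassert>
#include <functional>
#include <iterator>
#include <limits>
#include <set>
using namespace std;

// hypothetical element type: ordered by id, searched by xleft
struct seg {
  int id;
  mutable double xleft; // may be changed even through a const reference
};
// used by insert to order the elements
bool operator<(seg const &a, seg const &b) { return a.id < b.id; }
// used by upper_bound(x) when x is a double
bool operator<(double x, seg const &s) { return x < s.xleft; }

// id of the last segment whose xleft is <= x
// assumes that xleft increases with id and that the first xleft is -infinity
int find_segment(multiset<seg, less<>> const &segs, double x) {
  return prev(segs.upper_bound(x))->id;
}
```

A short usage example: after inserting seg{1, -infinity}, seg{2, 0.0} and seg{3, 5.0}, the call find_segment(segs, 0.0) yields 2. Writing next(segs.begin())->xleft = 2.0 is allowed because xleft is mutable, and afterwards find_segment(segs, 1.0) yields 1. Changing xleft is only safe as long as the order used by upper_bound stays consistent with the order used by insert.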
