Smaller to Larger (DSU on trees)

Written by Joël Benjamin Huber. Translated by Joël Benjamin Huber.

Trees are a very special kind of graph, and many procedures on trees can be optimized. One of these optimizations is called “Smaller to Larger”, also known as “DSU on trees”. We will see later why it is called “Smaller to Larger”.

Introductory Problem

Given a rooted tree with colored vertices, find for each vertex $v$ the number of vertices in the subtree rooted at $v$ with the same color as $v$.

[Figure 1: a tree rooted at vertex 0 with edges 0-1, 0-4, 0-5, 1-2, 1-3, 5-6, 5-7, 5-8; the labels are 4 at vertex 0, 2 at vertices 1 and 5, and 1 at all other vertices.]

Figure 1: An example graph for the introductory problem. The labels correspond to the solutions for the vertices.

Problems with short statements tend to be easy, right? The naive approach to this problem is to start a DFS from each vertex and count the vertices in its subtree with the same color. This runs in $\mathcal O(n^2)$. Let’s take a closer look at what we’re doing. It’s often useful to look at the information we discard while calculating, or at what we calculate several times. Looking at our algorithm, we see that it goes through each subtree again and again. Suppose we know, for a vertex $v$ and for every color, the number of vertices with that color in the subtree rooted at $v$. Then there’s no need to go through the whole subtree again.

This gives us the idea that instead of doing a DFS from each vertex, we store for each vertex the information about its subtree, and to get this information for a vertex, we somehow merge the information of its children. In this case, we can store for each subtree a map from colors to the number of vertices with this color. We can create the map for a vertex by copying the elements of its children’s maps into a new map and then adding the color of the vertex itself.
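A minimal sketch of this first map-based approach in C++. The names SlowSolver, adjlist, col and sol are illustrative and mirror the sample implementation further down. Every vertex builds a fresh map by copying all entries of its children’s maps. It computes the right answers, but every entry is copied again at every ancestor.

```cpp
#include <map>
#include <vector>
using namespace std;

// First (slow) attempt: every vertex builds a brand-new map by copying all
// entries of its children's maps. Correct, but O(n^2 log n) in the worst case.
struct SlowSolver
{
    vector<vector<int>> adjlist; // undirected adjacency list
    vector<int> col;             // color of each vertex
    vector<int> sol;             // sol[v] = vertices in subtree(v) with color col[v]

    map<int,int> dfs(int curr, int par)
    {
        map<int,int> cnt; // fresh map: color -> number of vertices with it
        for (int next : adjlist[curr])
        {
            if (next == par) continue;
            map<int,int> child = dfs(next, curr);
            for (const auto& p : child) cnt[p.first] += p.second; // copy every entry
        }
        cnt[col[curr]] += 1;        // add the vertex itself
        sol[curr] = cnt[col[curr]]; // answer for the introductory problem
        return cnt;
    }
};
```

Take the tree from Figure 1 (edges 0-1, 0-4, 0-5, 1-2, 1-3, 5-6, 5-7, 5-8) and give vertices 0 to 8 the hypothetical colors 0, 0, 0, 1, 0, 2, 2, 3, 3. Then dfs(0, -1) fills sol with 4, 2, 1, 1, 1, 2, 1, 1, 1.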

Note that the size of a map is bounded from above by the number of vertices in its subtree. The problem is that a map can contain $\mathcal O(n)$ elements, and we copy all of them into a new map at every vertex. So our running time is $\mathcal O(\text{number of vertices}) \cdot \mathcal O(\text{size of the maps}) \cdot \mathcal O(\text{time to insert into a map}) = \mathcal O(n) \cdot \mathcal O(n) \cdot \mathcal O(\log n) = \mathcal O(n^2 \log n)$, which is even worse than before. The advantage is that we no longer traverse the subtrees repeatedly. The “slowness” of the algorithm now lies in the way we copy the elements. But we can speed this up. Let’s see how.

Instead of creating a new map for each vertex, we can keep one of the children’s maps and copy the other maps into it. Which map should we keep? Intuitively, we should keep the largest one, since this minimizes the number of elements we need to copy. This works and reduces our total running time.

This is the trick called smaller to larger: when we work with some kind of tree structure and need to merge sets, we keep the largest set in each merge operation and copy the elements of the smaller sets into it. Let $n$ be the number of vertices and $m$ the number of elements. Then our running time is $\mathcal O(n + m \log m \cdot T_{\text{insert}})$, where $T_{\text{insert}}$ is the time to insert an element into a set. Most of the time, $m \in \mathcal O(n)$. If we use the usual set::insert() / map::operator[] functions, we get a running time of $\mathcal O(n \log^2 n)$. Let’s first see why this makes sense intuitively, before going on to a formal proof.
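The core of the trick fits in a few lines. The helper below is a sketch, and its name merge_smaller_into_larger is made up for illustration. It swaps the two sets if needed, which std::set does in constant time, and then copies the smaller set into the larger one.

```cpp
#include <set>
#include <utility>

// Merge b into a using the smaller-to-larger rule.
// Afterwards a holds the union and b is empty.
// Returns the number of insert calls, i.e. the size of the smaller set.
template <typename T>
long long merge_smaller_into_larger(std::set<T>& a, std::set<T>& b)
{
    if (a.size() < b.size()) a.swap(b); // O(1): only internal pointers are exchanged
    long long inserts = 0;
    for (const T& x : b)
    {
        a.insert(x); // copy one element of the smaller set
        ++inserts;
    }
    b.clear();
    return inserts;
}
```

Merging {1, 2} with {3, 4, 5, 6} performs only 2 insertions, because the four-element set is kept.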

Intuition

For the sake of simplicity, assume our tree is a binary tree. Let’s fix a vertex $v$. If we merge the maps of its two children arbitrarily, the worst case is copying a very big map into a very small one. If we’re clever, we check which map is smaller and merge the smaller map into the larger one. Then the new worst case is a balanced tree; in any other tree we save some operations. But when the tree is balanced, its height is about $\mathcal O(\log n)$, so each element is copied only about $\mathcal O(\log n)$ times. Combining this, we should get something around $\mathcal O(n) \cdot \mathcal O(\log n) \cdot \mathcal O(\log n) = \mathcal O(n \log^2 n)$. This is actually the final time complexity of this algorithm. However, this “intuition” is very far from a correct proof. So, let’s get to the proof.
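For a perfectly balanced binary tree, the intuition can be written as a recurrence. This is a sketch for this special case only, with $T(n)$ denoting the work for a subtree of $n$ vertices. At level $\ell$ there are $2^{\ell}$ subtrees of size $n/2^{\ell}$. In each of them, half of the elements are copied, each at cost $\mathcal O(\log n)$:

$$T(n) = 2\,T\!\left(\tfrac{n}{2}\right) + \tfrac{n}{2}\cdot\mathcal O(\log n) \;\Longrightarrow\; T(n) = \sum_{\ell=0}^{\log_2 n} 2^{\ell}\cdot\tfrac{n}{2^{\ell+1}}\cdot\mathcal O(\log n) = \mathcal O(n\log^2 n)$$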

Formal Proof

For the formal proof, we change the setting a bit. Suppose we initially have $n$ sets with a total of $m$ elements in them. We define an operation that adds an element to a set: S.add_element(e) adds element e to S. We restrict add_element to operations that increase the cardinality of the set by at most one. For example, the normal insert() operation is a valid choice for add_element. In our case, we can make our elements pairs of integers {color, number of vertices with this color}, which are easier to code using a map. We can then choose our add_element() operation to be the following:

def add_element(col, cnt):
    if ({col, k} already in the set for some k):
        replace {col, k} by {col, k + cnt}
    else:
        insert {col, cnt} into the set

This can be done with C++ maps in a very easy way:

void add_element(map<int,int>& mp, pair<int,int> p)
{
    mp[p.first] += p.second;
}

Note that the final result of this operation does not depend on the order in which elements are added, and that each call increases the cardinality of the set by at most one. We now repeat the following operation until only one set is left: choose some sets and replace them with one set S, constructed by:

def merge(S_0, S_1, ...):
    S = the set among S_0, S_1, ... with the largest cardinality

    for (each set among S_0, S_1, ... other than S):
        for (each element e in this set):
            S.add_element(e)

    return S

Note that this is what we were doing before: in the beginning, each vertex is a single set, and then we merge the sets according to the tree structure, which can be arbitrary. If add_element(e) increases the cardinality of the target set by one, we call the newly added element a successor of e. Now let e be an element in any of the $n$ initial sets. Let’s look at how many times we call add_element with e or a (direct or indirect) successor of e as argument. Since we discard the “old” sets after merging, at any moment at most one of e and its successors is alive.

If add_element(e) does not add a new element to the target set, then e has no successor, and it never will. So neither e nor anything derived from it is ever merged into another set, which only reduces the total number of operations. For an upper bound, we can therefore assume that add_element(e) always adds exactly one element, since this is the worst case for us.

Let’s now count how many times we call add_element with e or a successor of e as argument. To bound this number, we track one quantity: the cardinality of the set that contains the current successor of e (or e itself). Why this quantity? Because we somehow need to use the fact that we always merge into the largest set.

Now denote the set our element is currently in as $E$, the largest of the sets we merge as $L$, and the resulting set as $R$. We never call add_element for elements of the largest set, because we keep that set as it is, so we can assume $E \ne L$. Writing down some inequalities, we get $|R| \ge |L| + |E| + (\text{sizes of the other merged sets}) \ge |L| + |E| \ge 2|E|$. This is very useful: every time we call add_element with e or a successor of e, the set containing the new successor is at least twice as large as the set that contained the old element. The cardinality of a set can never exceed $m$ (the total number of elements cannot increase, since at most one successor of each element is alive). Let $c$ be the number of calls to add_element with e or a successor of e as argument. The set containing e initially has at least one element, and its size at least doubles with each call. So $2^c \le m$ must hold. Taking the logarithm on both sides, we get $c \le \log_2 m$.

So each element of the initial sets causes at most $\log_2 m$ calls to add_element. The total time for all calls to add_element is therefore $\mathcal O(m) \cdot \mathcal O(\log m) \cdot \mathcal O(T_{\text{add}})$, where $T_{\text{add}}$ is the time for one call to add_element. We should not forget to add $\mathcal O(n)$, because we need to look at each set, but normally the merging dominates the total running time. Most of the time, $m \in \mathcal O(n)$ and $T_{\text{add}} \in \mathcal O(\log n)$, which gives a running time of $\mathcal O(n \log^2 n)$.
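The bound $c \le \log_2 m$ can also be checked empirically. The following sketch is written only for this experiment, and the function max_moves_random_merges is hypothetical. It starts with $m$ singleton sets of distinct elements, so every add_element really adds one element. It then merges random pairs smaller into larger and reports the largest number of times any single element was copied.

```cpp
#include <algorithm>
#include <random>
#include <utility>
#include <vector>
using namespace std;

// Merge m singleton sets pairwise in random order, always copying the smaller
// set into the larger one. Returns the maximum number of copies of one element.
int max_moves_random_merges(int m, unsigned seed)
{
    mt19937 rng(seed);
    vector<vector<int>> sets(m);            // sets[i] = elements of set i
    for (int i = 0; i < m; i++) sets[i] = {i};
    vector<int> moves(m, 0);                // moves[e] = times element e was copied
    vector<int> alive(m);                   // indices of sets that still exist
    for (int i = 0; i < m; i++) alive[i] = i;

    while (alive.size() > 1)
    {
        // pick two distinct live sets uniformly at random
        uniform_int_distribution<size_t> pick(0, alive.size() - 1);
        size_t x = pick(rng), y = pick(rng);
        while (y == x) y = pick(rng);
        int a = alive[x], b = alive[y];
        if (sets[a].size() < sets[b].size()) swap(a, b); // a is the larger set
        for (int e : sets[b])
        {
            sets[a].push_back(e); // elements are distinct: add_element adds one
            moves[e]++;
        }
        sets[b].clear();
        // remove b from the list of live sets (swap with the last entry)
        size_t pos = (alive[x] == b) ? x : y;
        alive[pos] = alive.back();
        alive.pop_back();
    }
    return *max_element(moves.begin(), moves.end());
}
```

By the proof, the returned value never exceeds $\lfloor \log_2 m \rfloor$, whatever the merge order.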

Sample implementation

Here you can find an implementation of Smaller to Larger, which (instead of merging all the children together in one merging process) merges the sets of the children one by one. This implementation counts the number of distinct colors in each subtree. What do you need to modify such that it solves our introductory problem?

#include <set>
#include <vector>
using namespace std;

struct graph
{
    vector<vector<int>> adjlist;
    vector<int> col;
    vector<int> sol;

    set<int> dfs(int curr, int par)
    {
        set<int> cst; // Create a new set (no parentheses: "cst()" would declare a function)
        cst.insert(col[curr]); // Insert color of this vertex

        for (auto next : adjlist[curr])
        {
            if (next == par) continue;

            auto nst = dfs(next, curr); // Get the set of this child
            if (nst.size() > cst.size()) swap(nst, cst);
            // If the child's set is larger, swap (in O(1)) so that we copy the smaller set into the larger one

            for (auto p : nst)
            {
                cst.insert(p); // Copy the values from the smaller to the larger set
            }
        }

        sol[curr] = cst.size(); // Save the solution for this vertex
        return cst; // Return the current set.
    }
};

Sample Problems

BOI 2017 Railway

Note: by “upwards” I mean towards the root, and by “downwards” away from the root.

The ministry of infrastructure in Bergen wants to build a new railway in order to connect the $n$ stations. They only wanted to build $n - 1$ tracks such that all stations are connected. However, they found out they cannot build all of the tracks in time. So they decided to ask $m$ ministers which tracks they think should be built. Each minister wrote down a list of neighbourhoods (stations) that should be connected: the $i$-th minister wrote down $s_i$ vertices $a_{i,0}, a_{i,1}, \dots, a_{i,s_i-1}$. For each pair of vertices in the list, the minister wants all the tracks on the direct path between these two vertices to be built. The ministry then decided to construct all the tracks that are requested by at least $k$ ministers. Our task is to figure out which tracks will be built. The limits are $n \le 100\,000$, $k \le m \le 50\,000$ and $\sum_i s_i \le 100\,000$.

In this task, we need to exploit the tree structure several times. First, let’s look at what the set of tracks a minister wants to build looks like. Before reading on, take a piece of paper and try to find a nice way to describe these tracks.

In such problems, it’s always useful to root the tree. After looking at some examples, it seems like the $LCA$ plays an important role. Indeed, a minister wants a track to be built if and only if it lies on the direct path between some selected station and the $LCA$ of all selected stations. This is not hard to prove with some basic facts about $LCA$s. It also uses the fact that each path from $a$ to $b$ in the tree can be split into a path upwards from $a$ to $LCA(a, b)$ and a path downwards from $LCA(a, b)$ to $b$.

Now that we understand what the tracks selected by each minister look like, we need to figure out how to process them efficiently. Suppose a minister selected a station $a$. We get all the tracks this station adds by starting at this vertex, always moving upwards and “marking” the visited edges until we reach the $LCA$. This already smells a bit like smaller to larger. Let’s say that a minister is active on a vertex $v$ if the edge from $v$ upwards is selected by this minister. The minister is active on $v$ if and only if two conditions hold. First, the minister is active on some child of $v$ or selected $v$ itself. Second, $v$ is not the $LCA$ of all the stations the minister selected.

This is solvable with smaller to larger: we calculate for each vertex the set of active ministers, merging the sets of the children efficiently. For each vertex we only need to store two lists. The first holds the ministers who selected it; we insert them into the set. The second holds the ministers for whom it is the $LCA$ of their stations; we remove them from the set after merging. For each edge, the number of ministers who want to build it is then simply the number of active ministers on the vertex directly below the edge.
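Before writing the smaller-to-larger solution, it can help to have a brute force for small inputs that literally does the marking described above. This is a sketch under stated assumptions. The tree is given by a parent array with parent[root] == root and a depth array, each edge is identified by its lower endpoint, and every list is non-empty. The function name count_ministers_naive is hypothetical. It walks up to $\mathcal O(n)$ edges per station, so it is only useful for cross-checking.

```cpp
#include <set>
#include <vector>
using namespace std;

// Brute force: for each minister, walk up from every selected station to the
// LCA of all its stations and mark the edges passed. Edge v = (v, parent[v]).
// Returns cnt[v] = number of ministers who want the edge above v.
vector<int> count_ministers_naive(const vector<int>& parent, const vector<int>& depth,
                                  const vector<vector<int>>& lists)
{
    int n = parent.size();
    vector<int> cnt(n, 0);
    // naive LCA by climbing parent pointers
    auto lca = [&](int a, int b) {
        while (depth[a] > depth[b]) a = parent[a];
        while (depth[b] > depth[a]) b = parent[b];
        while (a != b) { a = parent[a]; b = parent[b]; }
        return a;
    };
    for (const auto& stations : lists)
    {
        int l = stations[0];
        for (int s : stations) l = lca(l, s); // LCA of all selected stations
        set<int> marked;                      // edges this minister wants
        for (int s : stations)
            for (int v = s; v != l; v = parent[v]) marked.insert(v);
        for (int v : marked) cnt[v]++;
    }
    return cnt;
}
```

Consider a root 0 with children 1 and 2, where vertex 1 has children 3 and 4, and the three lists {3, 4}, {3, 2} and {4}. Then the edges above vertices 0 to 4 are wanted by 0, 1, 1, 2 and 1 ministers.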

BOI 2017 Railway
#include <iostream>
#include <vector>
#include <algorithm>
#include <set>

#define int int64_t
#define log2(x) (8*sizeof (long long) - __builtin_clzll(x))

using namespace std;

struct tree
{
   int n;
   vector<vector<int>> adjlist;
   vector<vector<int>> st;
   vector<vector<int>> en;
   vector<int> res;

   set<int>* smldfs(int curr, int par) // Our smaller-to-larger dfs
   {
      // Create new set with the newly active minister
      set<int>* rs = new set<int>(st[curr].begin(), st[curr].end());

      for (auto next : adjlist[curr])
      {
         if (next == par) continue;

         // Get the next set and look that we merge into the larger set
         set<int>* ot = smldfs(next, curr);
         if (ot->size() > rs->size()) swap(ot, rs);

         // Merge the sets, then free the smaller one (it is no longer needed)
         for (auto i : (*ot))
         {
            rs->insert(i);
         }
         delete ot;
      }

      // Erase the ministers that are not active anymore
      for (auto i : en[curr])
      {
         rs->erase(i);
      }

      // Save the number of different ministers active at this vertex
      res[curr] = rs->size();

      // Return the set
      return rs;
   }

   vector<vector<int>> table;
   vector<int> lvl;
   int logn;

   // The dfs to prepare our LCA
   void dfs(int curr, int par, int lv)
   {
      lvl[curr] = lv;
      table[0][curr] = par;
      for (auto next : adjlist[curr]) if (next != par) dfs(next, curr, lv + 1);
   }

   // Precalculate for the LCA
   void prelca()
   {
      logn = log2(n) + 1;
      table.resize(logn, vector<int>(n));
      lvl.resize(n);
      dfs(0, 0, 0);

      // Create the Binary-Lifting table
      for (int i = 1; i < logn; i++)
      {
         for (int j = 0; j < n; j++)
         {
            table[i][j] = table[i - 1][table[i - 1][j]];
         }
      }
   }

   // Calculate the lca
   int lca(int a, int b)
   {
      if (lvl[a] < lvl[b]) swap(a, b);
      int ld = lvl[a] - lvl[b];

      for (int j = 0; j < logn; j++)
      {
         if (ld & (1<<j)) a = table[j][a];
      }
      if (a == b) return a;
      for (int j = logn - 1; j >= 0; j--)
      {
         if (table[j][a] != table[j][b])
         {
            a = table[j][a];
            b = table[j][b];
         }
      }
      return table[0][a];
   }
};

signed main()
{
   // I/O stuff
   ios_base::sync_with_stdio(0);
   cin.tie(0);

   int n, m, k;
   cin >> n >> m >> k;

   tree t;
   t.n = n;
   t.adjlist.resize(n);
   vector<pair<int,int>> edges;

   for (int i = 0; i < n - 1; i++)
   {
      int a, b;
      cin >> a >> b;
      a--; b--;
      edges.emplace_back(a, b);
      t.adjlist[a].push_back(b);
      t.adjlist[b].push_back(a);
   }

   // Precalculate LCA
   t.prelca();
   t.st.resize(n);
   t.en.resize(n);

   // Read the vertices from the ministers
   for (int i = 0; i < m; i++)
   {
      int c;
      cin >> c;
      int l = -1;
      for (int j = 0; j < c; j++)
      {
         int a;
         cin >> a;
         a--;
         // Mark this vertex as a point where a minister gets active
         t.st[a].push_back(i);
         if (l == -1) l = a;
         else l = t.lca(l, a);
      }

      // Mark this vertex as the vertex where the minister gets inactive
      t.en[l].push_back(i);
   }

   // Run the dfs
   t.res.resize(n);
   t.smldfs(0, -1);
   // The number of active ministers for each vertex is now saved in t.res

   // Check all the edges
   vector<int> sol;
   for (int i = 0; i < n - 1; i++)
   {
      auto e = edges[i];
      int a = (t.lvl[e.first] > t.lvl[e.second] ? e.first : e.second);
      if (t.res[a] >= k) sol.push_back(i + 1);
   }

   cout << sol.size() << "\n";
   for (auto i : sol) cout << i << " ";
   cout << "\n";
}

CEOI 2019 Magictree

In this problem, we’re given a rooted tree. Some vertices of this tree have fruits. Each fruit $i$ is ripe on exactly one day $d_i$ and has a juiciness of $j_i$. We can harvest a ripe fruit by cutting off the vertex the fruit is on. But then all the fruits we didn’t harvest that are no longer connected to the root die. Our goal is to maximize the total juiciness of the fruits we harvest.

This looks like some kind of dp. Indeed, we can define $dp[v][i]$ as the maximal juiciness we can harvest from the subtree rooted at $v$ in the first $i$ days. The recursion formula is not hard to come up with when we look at the two cases (cut $v$ or don’t cut it). Here $d_v$ and $j_v$ denote the day and the juiciness of the fruit at $v$:

$$dp[v][i] = \begin{cases} \max\left( \sum\limits_{u \text{ child of } v} dp[u][i],\; j_v + \sum\limits_{u \text{ child of } v} dp[u][d_v] \right) & \text{if } d_v \le i \\ \sum\limits_{u \text{ child of } v} dp[u][i] & \text{otherwise} \end{cases}$$

Our dp has $\mathcal O(n \cdot \max_i d_i)$ states, and this is also the final running time of this algorithm (what happens to the transitions?). But we need to do better. Clearly, we want to get rid of the factor $\max_i d_i$. So what is missing? The observations.
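For small inputs, the recurrence can be evaluated directly. The function name magictree_naive in the sketch below is hypothetical. It assumes every vertex $v \ge 1$ has a parent with a smaller index, so iterating from $n - 1$ down to $0$ handles children before their parents. Vertices without a fruit are marked by j[v] = 0. It takes $\mathcal O(n \cdot k)$ time and memory, which is exactly what we want to improve on, but it is handy as a reference.

```cpp
#include <algorithm>
#include <vector>
using namespace std;

// Direct evaluation of dp[v][i] for i = 0..k; returns dp[root][k].
// Assumes parent[v] < v for v >= 1 (vertex 0 is the root).
// d[v] = ripening day of the fruit at v, j[v] = its juiciness (0 = no fruit).
long long magictree_naive(const vector<int>& parent, const vector<int>& d,
                          const vector<long long>& j, int k)
{
    int n = parent.size();
    vector<vector<long long>> dp(n, vector<long long>(k + 1, 0));
    // sum[v][i] = sum of dp[u][i] over the children u of v
    vector<vector<long long>> sum(n, vector<long long>(k + 1, 0));
    for (int v = n - 1; v >= 0; v--)
    {
        for (int i = 0; i <= k; i++)
        {
            dp[v][i] = sum[v][i];                        // case: do not cut v
            if (j[v] > 0 && d[v] <= i)                   // case: cut v on day d[v]
                dp[v][i] = max(dp[v][i], j[v] + sum[v][d[v]]);
        }
        if (v > 0)
            for (int i = 0; i <= k; i++) sum[parent[v]][i] += dp[v][i];
    }
    return dp[0][k];
}
```

On the path 0 - 1 - 2 with a fruit at 1 (day 2, juiciness 5) and a fruit at 2 (day 1, juiciness 3), the result for $k = 2$ is 8: cut 2 on day 1, then 1 on day 2.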

Let’s plot the dp array of a vertex. We can make two observations: the array consists of several constant segments, and it is non-decreasing. The second one is obvious: since we’re looking at the first $i$ days, in the first $j > i$ days we could always make the same cuts as in the first $i$ days if there is no better option. For the first one, note that the only interesting days are the ones on which some fruit is ripe. If no new fruit is ripe on day $i + 1$, we have the same possibilities as on day $i$. Thus we can compress the days to get an $\mathcal O(n^2)$ solution.

https://i.imgur.com/bPWt2el.png

Figure 2: Plot of the dp of a vertex over time. The red line marks the values for which the fruit at vertex $v$ is taken.

But instead of compressing, we just change the representation of our dp. Instead of storing the values, we store the differences between consecutive values. You can see this as the inverse of a prefix sum: if we take prefix sums of the new array, we get the original dp back. We lose the ability to look up a value of our dp in $\mathcal O(1)$, but we don’t care. We only need one value, namely that of the root, and we have plenty of time to calculate it at the end. Since most of the values in this array are zero, we discard them and store the rest in a map.

What about the transitions? Note that summing up the original dp arrays ($\sum_{u \text{ child of } v} dp[u][i]$) is equivalent to summing up the new difference arrays. Thus we can just merge our maps and add up the values with the same key. Hey, we’re merging maps! Does this ring a bell? It should! We can merge the smaller maps into the larger ones, which reduces our running time.
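The change of representation and the merge can be tried out in isolation. The names to_diff, from_diff and add_diff below are illustrative, not part of the original solution. to_diff converts a dp array into its difference map. from_diff recovers the array by prefix sums. add_diff adds two difference maps by merging the smaller map into the larger one.

```cpp
#include <map>
#include <utility>
#include <vector>
using namespace std;

// Difference map: keep only the days on which the dp value increases.
map<int,long long> to_diff(const vector<long long>& dp)
{
    map<int,long long> diff;
    for (int i = 0; i < (int)dp.size(); i++)
    {
        long long prev = (i == 0) ? 0 : dp[i - 1];
        if (dp[i] != prev) diff[i] = dp[i] - prev;
    }
    return diff;
}

// Prefix sums turn a difference map back into a dp array of length len.
vector<long long> from_diff(const map<int,long long>& diff, int len)
{
    vector<long long> dp(len, 0);
    long long run = 0;
    auto it = diff.begin();
    for (int i = 0; i < len; i++)
    {
        if (it != diff.end() && it->first == i) { run += it->second; ++it; }
        dp[i] = run;
    }
    return dp;
}

// Summing two dp arrays == adding their difference maps key by key.
void add_diff(map<int,long long>& a, map<int,long long>& b)
{
    if (a.size() < b.size()) a.swap(b); // keep the larger map
    for (const auto& p : b) a[p.first] += p.second;
    b.clear();
}
```

For example, the arrays 0, 0, 3, 3, 5 and 0, 1, 1, 1, 4 become the maps {2: 3, 4: 2} and {1: 1, 4: 3}. Their merge {1: 1, 2: 3, 4: 5} has prefix sums 0, 1, 4, 4, 9, which is the sum of the two arrays.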

This is, however, not the whole transition formula. Looking at the graph again, we notice that the formula just adds another constant segment on top of the sum we calculated. We can add it with an entry at position $d_v$ with value $j_v$. But if we only add this entry, all later values also increase by $j_v$, while we only want to add this new segment. So we subtract from the following entries until we have subtracted $j_v$ in total, deleting entries that become empty. Since each entry is inserted only once, it can also be erased only once. So, amortized, we delete $\mathcal O(n)$ entries in total.
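The segment insertion can be written as a standalone function. This sketch uses the hypothetical name add_fruit. It takes the merged difference map of the children and applies the fruit at $v$, ripe on day dv with juiciness jv. Afterwards, for $i \ge d_v$, the prefix sums equal $\max(\text{old } dp[i],\; j_v + \text{old } dp[d_v])$, as the recurrence requires.

```cpp
#include <iterator>
#include <map>
using namespace std;

// Apply a fruit (day dv, juiciness jv) to a difference map:
// new dp[i] = max(old dp[i], jv + old dp[dv]) for all i >= dv.
void add_fruit(map<int,long long>& diff, int dv, long long jv)
{
    auto cur = diff.emplace(dv, 0).first; // entry at dv (created if missing)
    cur->second += jv;                    // the new segment starts at day dv
    long long left = jv;                  // amount still to remove from later entries
    while (left > 0 && next(cur) != diff.end())
    {
        auto nxt = next(cur);
        if (left >= nxt->second)
        {
            left -= nxt->second;          // this increase is absorbed completely
            diff.erase(nxt);
        }
        else
        {
            nxt->second -= left;          // partially absorbed; we are done
            left = 0;
        }
    }
}
```

For example, the map {2: 3, 5: 4} (dp values 0, 0, 3, 3, 3, 7, ...) with a fruit on day 3 of juiciness 5 becomes {2: 3, 3: 5}, i.e. dp values 0, 0, 3, 8, 8, 8, ...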

We only insert one element per vertex, so with smaller to larger we get a total running time of $\mathcal O(n \log^2 n)$.

CEOI 2019 Magictree
#include <iostream>
#include <vector>
#include <map>

#define int int64_t

using namespace std;

struct graph
{
   vector<vector<int>> adjlist;
   vector<int> ds;
   vector<int> js;

   map<int,int>* dfs(int curr)
   {
      auto curr_map = new map<int,int>(); // Create a new map-pointer

      for (auto next : adjlist[curr])
      {
         auto new_map = dfs(next); // Get map of child

         // Make sure that we merge the smaller map into the larger one
         if (curr_map->size() < new_map->size()) swap(curr_map, new_map);

         // Merge the maps, then free the smaller one (it is no longer needed)
         for (auto p : *new_map)
         {
            (*curr_map)[p.first] += p.second;
         }
         delete new_map;
      }

      // Insert the element of the current vertex
      (*curr_map)[ds[curr]] += js[curr];
      auto curr_el = curr_map->find(ds[curr]);

      int left = js[curr];

      // Subtract from the following entries
      while (left > 0 && next(curr_el) != curr_map->end())
      {
         auto next_el = next(curr_el); // Next element

         if (left >= next_el->second)
         {
            // Remove next element
            left -= next_el->second;
            curr_map->erase(next_el);
         }
         else
         {
            curr_map->at(next_el->first) -= left;
            break; // We subtracted enough
         }
      }

      return curr_map;
   }
};

signed main()
{
   // I/O things
   ios_base::sync_with_stdio(0);
   cin.tie(0);

   int n, m, k;
   cin >> n >> m >> k;

   graph g;
   g.adjlist.resize(n);
   g.ds.resize(n, 1e16);
   g.js.resize(n, 0);

   for (int i = 1; i < n; i++)
   {
      int p;
      cin >> p; p--;
      g.adjlist[p].push_back(i);
   }

   for (int i = 0; i < m; i++)
   {
      int v;
      cin >> v; v--;
      cin >> g.ds[v] >> g.js[v];
   }

   // Calculate the result

   auto fm = *g.dfs(0);
   int res = 0;

   for (auto p : fm)
   {
      res += p.second;
   }

   cout << res << "\n";
}