The problem

Given an array A, we call (i, j) an important reverse pair if i < j and A[i] > 2*A[j]. Count the number of important reverse pairs.

For this kind of problem, a Segment Tree or a Binary Indexed Tree (Fenwick tree) is often the first tool people reach for. However, let's look at a more natural (i.e. easier to understand) solution that exploits a side effect of Merge Sort.

Merge Sort

Merge Sort was invented by the great John von Neumann in 1945 and is a classic example of a divide-and-conquer algorithm. It has some advantages over other sorting algorithms:

  • Merge Sort has an average-case and worst-case running time of O(n log n). In other words, its performance is predictable regardless of the input order.
  • Merge Sort is often the best choice for sorting a linked list, where a bottom-up implementation needs only O(1) extra space (Quicksort performs poorly on linked lists, and Heapsort relies on random access, so it cannot be applied to them directly). For the same reason, Merge Sort works well on slow-to-access sequential media.

The basic idea behind Merge Sort is very natural:

  1. split the array into two (nearly) equal halves
  2. sort each half recursively
  3. merge the two sorted halves into one sorted array.

(Figure: an illustration of merge sort.)

The key point is step 3: merging two sorted arrays takes only O(n) time.

// `temp` is the scratch buffer used by merge(). `a` is not used by the
// functions below (they take the array as a parameter that shadows it).
std::vector<long> temp, a;

// Merge the two adjacent sorted runs a[left..mid] and a[mid+1..right]
// (inclusive bounds) into a single sorted run a[left..right].
// Linear in the number of elements merged; uses the global `temp` as scratch.
void merge(std::vector<long> &a, int left, int mid, int right) {
    temp.clear();
    int i = left, j = mid + 1;
    while (i <= mid && j <= right) {
        // `<=` takes from the left run on ties, which keeps the sort stable.
        if (a[i] <= a[j]) temp.push_back(a[i++]);
        else temp.push_back(a[j++]);
    }
    // At most one of these copies anything: the unfinished run's tail.
    while (i <= mid) temp.push_back(a[i++]);
    while (j <= right) temp.push_back(a[j++]);

    for (int k = left; k <= right; k++) a[k] = temp[k - left];
}

// Sort a[left..right] (inclusive bounds) in ascending order with top-down
// merge sort. An empty range (left > right, e.g. right == -1 for an empty
// vector) or a single element is already sorted.
void mergeSort(std::vector<long> &a, int left, int right) {
    // `>=` rather than `==`: with `==`, an empty range recurses forever.
    if (left >= right) return;

    int mid = left + (right - left) / 2;  // avoids overflow of left + right
    mergeSort(a, left, mid);
    mergeSort(a, mid + 1, right);
    merge(a, left, mid, right);
}

At each recursive step, just before the two sorted halves are merged, we have:

  • 2 sorted arrays X and Y.
  • every element of X had a smaller original index than every element of Y (X came from the left half of the range, Y from the right half).

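Taking advantage of this side effect, for each index i in X we can count the indices j in Y with X[i] > 2 * Y[j] using the two-pointer technique: because both X and Y are sorted, the pointer into Y only ever moves forward as X[i] grows, so the counting for one merge step takes O(|X| + |Y|) time.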

// `temp` is the scratch buffer used by merge(); `a` is not used by the
// functions below. `res` accumulates the number of important reverse pairs.
std::vector<int> temp, a;
int res;

// Merge the sorted runs a[left..mid] and a[mid+1..right] (inclusive bounds)
// into one sorted run a[left..right], using the global `temp` as scratch.
void merge(std::vector<int> &a, int left, int mid, int right) {
    temp.clear();
    int i = left, j = mid + 1;
    while (i <= mid && j <= right) {
        // `<=` takes from the left run on ties (stable merge).
        if (a[i] <= a[j]) temp.push_back(a[i++]);
        else temp.push_back(a[j++]);
    }

    while (i <= mid) temp.push_back(a[i++]);
    while (j <= right) temp.push_back(a[j++]);

    for (int k = left; k <= right; k++)
        a[k] = temp[k - left];
}

// Sort a[left..right] (inclusive bounds) and add to `res` the number of pairs
// (i, j) inside this range with i < j and a[i] > 2 * a[j] (original order).
void mergeSort(std::vector<int> &a, int left, int right) {
    if (left >= right) return;  // empty or single-element range
    int mid = left + (right - left) / 2;
    mergeSort(a, left, mid);
    mergeSort(a, mid + 1, right);

    // Both halves are sorted now, and every left-half element came from a
    // smaller original index than every right-half element. As a[i] grows,
    // the set of right-half j with 2 * a[j] < a[i] only grows, so j never
    // moves backwards and this loop is linear in the range length.
    int j = mid + 1;
    for (int i = left; i <= mid; i++) {
        // Double in 64 bits: `2L` is only 32-bit where `long` is 32-bit
        // (LLP64, e.g. 64-bit Windows), so 2 * a[j] could overflow there.
        while (j <= right && 2LL * a[j] < a[i]) j++;
        res += j - (mid + 1);  // a[mid+1..j-1] all pair with a[i]
    }

    merge(a, left, mid, right);
}

// Count the pairs (i, j) with i < j and nums[i] > 2 * nums[j].
// NOTE: sorts `nums` in place as a side effect.
// Assumes the count (at most n*(n-1)/2) fits in an int.
int reversePairs(std::vector<int>& nums) {
    res = 0;  // reset before any early return so the global never stays stale
    if (nums.empty()) return 0;
    mergeSort(nums, 0, static_cast<int>(nums.size()) - 1);
    return res;
}

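The running time is still O(n log n): the two-pointer pass adds only O(n) work per level of recursion, the same as the merge itself.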

The pattern

Using Merge Sort, we can solve similar problems with the same pattern:

  • Given an array A, count the number of pairs (i, j)
    such that i < j and A[i], A[j] satisfy some condition.
  1. Count of Smaller Numbers After Self

Given an integer array A, return a new array counts, where counts[i] is the number of elements to the right of A[i] that are smaller than A[i].

// (value, original index). The index lets counts be credited to the right
// output slot after sorting has reordered the values.
using II = std::pair<int, int>;
std::vector<int> res;   // res[k] = count of smaller elements right of nums[k]
std::vector<II> a, temp;

// Merge the sorted runs a[left..mid] and a[mid+1..right] (inclusive bounds,
// ordered by value) into one sorted run a[left..right].
void merge(std::vector<II> &a, int left, int mid, int right) {
    temp.clear();
    int i = left, j = mid + 1;
    while (i <= mid && j <= right) {
        // `<=` takes from the left run on ties (stable merge).
        if (a[i].first <= a[j].first) temp.push_back(a[i++]);
        else temp.push_back(a[j++]);
    }

    while (i <= mid) temp.push_back(a[i++]);
    while (j <= right) temp.push_back(a[j++]);
    for (int k = left; k <= right; k++)
        a[k] = temp[k - left];
}

// Sort a[left..right] by value. For every element of the left half, add the
// number of right-half elements with a strictly smaller value to its slot in
// `res`.
void mergeSort(int left, int right, std::vector<II> &a) {
    // `>=` rather than `==`: an empty range (right == -1) must stop here,
    // otherwise the recursion never terminates.
    if (left >= right) return;

    int mid = left + (right - left) / 2;  // avoids overflow of left + right
    mergeSort(left, mid, a);
    mergeSort(mid + 1, right, a);

    // Both halves are sorted by value, so j only moves forward.
    int j = mid + 1;
    for (int i = left; i <= mid; i++) {
        while (j <= right && a[j].first < a[i].first) j++;
        res[a[i].second] += j - (mid + 1);
    }

    merge(a, left, mid, right);
}

// For each nums[k], count the elements to its right that are strictly smaller.
std::vector<int> countSmaller(std::vector<int>& nums) {
    const int n = static_cast<int>(nums.size());
    // Rebuild `a` from scratch: it is global, and leftover pairs from a
    // previous call would corrupt the result (or index `res` out of range).
    a.clear();
    a.reserve(n);
    for (int i = 0; i < n; i++)
        a.push_back({nums[i], i});

    res.assign(n, 0);
    mergeSort(0, n - 1, a);

    return res;
}
  2. Count of Range Sum

Given an integer array A, return the number of range sums that lie in [lower, upper] inclusive.
The range sum S(i, j) is defined as A[i] + A[i+1] + ... + A[j-1] + A[j].

Let's define prefix sums P[0] = 0 and P[k] = A[0] + A[1] + ... + A[k-1] for k = 1..n. Then S(i, j) = P[j+1] - P[i]. The problem becomes: given the array P (of length n + 1), count the number of pairs (p, q) such that p < q and P[q] - P[p] lies in [lower, upper].

// res: running count of qualifying pairs. low/high: copies of [lower, upper]
// read by mergeSort(). temp: merge scratch buffer. a: the prefix sums being sorted.
// NOTE(review): `long` is 64-bit on LP64 (Linux/macOS) but 32-bit on LLP64
// (64-bit Windows), where large prefix sums would overflow. Switch to
// `long long` throughout if portability to LLP64 matters.
int res = 0, low, high;
std::vector<long> temp, a;

// Merge the sorted runs a[left..mid] and a[mid+1..right] (inclusive bounds)
// into one sorted run a[left..right], using the global `temp` as scratch.
void merge(std::vector<long> &a, int left, int mid, int right) {
    temp.clear();
    int i = left, j = mid + 1;
    while (i <= mid && j <= right) {
        // `<=` takes from the left run on ties (stable merge).
        if (a[i] <= a[j]) temp.push_back(a[i++]);
        else temp.push_back(a[j++]);
    }
    while (i <= mid) temp.push_back(a[i++]);
    while (j <= right) temp.push_back(a[j++]);

    for (int k = left; k <= right; k++) a[k] = temp[k - left];
}

// Sort a[left..right] and add to `res` the number of pairs p < q inside the
// range with a[q] - a[p] in [low, high] (original order).
void mergeSort(std::vector<long> &a, int left, int right) {
    if (left >= right) return;  // empty or single-element range

    int mid = left + (right - left) / 2;
    mergeSort(a, left, mid);
    mergeSort(a, mid + 1, right);

    // For each a[i] in the sorted left half, [j1, j2) is the block of the
    // sorted right half with a[j] - a[i] in [low, high]:
    //   j1 = first j with a[j] - a[i] >= low
    //   j2 = first j with a[j] - a[i] >  high
    // Both thresholds rise with a[i], so neither pointer moves backwards.
    int j1 = mid + 1, j2 = mid + 1;
    for (int i = left; i <= mid; i++) {
        while (j1 <= right && a[j1] - a[i] < low) j1++;
        while (j2 <= right && a[j2] - a[i] <= high) j2++;

        res += j2 - j1;  // non-negative because low <= high (checked by caller)
    }

    merge(a, left, mid, right);
}

// Count the ranges nums[i..j] (i <= j) whose sum lies in [lower, upper].
// Assumes the count fits in an int.
int countRangeSum(std::vector<int>& nums, int lower, int upper) {
    // `res` is global: reset it, or repeated calls would keep accumulating.
    res = 0;
    // An empty interval (lower > upper) contains no sums. This also keeps
    // j2 - j1 in mergeSort from going negative.
    if (nums.empty() || lower > upper) return 0;
    low = lower;
    high = upper;
    const int n = static_cast<int>(nums.size());

    // a[k] = nums[0] + ... + nums[k-1], with a[0] = 0. The sum of nums[i..j]
    // is a[j + 1] - a[i], so each range corresponds to one pair p < q of a.
    a = std::vector<long>(n + 1);
    for (int k = 0; k < n; k++)
        a[k + 1] = a[k] + nums[k];

    mergeSort(a, 0, n);
    return res;
}
}