跳转至

二维树状数组

二维树状数组,也被称作树状数组套树状数组,用来维护二维数组上的单点修改和前缀信息问题。其实就是一维树状数组上的每个节点变成了一个一维树状数组:

Image title

二维树状数组可以实现单点修改,子矩阵查询,也可以通过维护多个差分数组来实现子矩阵的修改和查询。

单点修改

修改一个矩阵的值后,我们要将其父节点都修改,例如我们要修改上图的 \(f[2][2]\) 元素,那么我们同时也要修改 \(f[2][5]、f[5][2]、f[5][5]\) ,这些节点,实现如下:

void update(int x, int y, int d)
{
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            f[i][j] += d;
}

时间复杂度为 \(O(\log n\log m)\)

子矩阵查询

现在我们可以求 \(\sum_{i=1}^{n}\sum_{j=1}^{m}a[i][j]\) 的值,要我们计算下图中黑色矩阵的值

Image title

很明显,计算黑色矩阵的面积有下面的公式:

\[ \sum_{i=x_2}^{x_1}\sum_{j=y_2}^{y_1}a[i][j]=\sum_{i=1}^{x_1}\sum_{j=1}^{y_1}a[i][j]-\sum_{i=1}^{x_2-1}\sum_{j=1}^{y_1}a[i][j]-\sum_{i=1}^{x_1}\sum_{j=1}^{y_2-1}a[i][j]+\sum_{i=1}^{x_2}\sum_{j=1}^{y_2}a[i][j] \]

有了公式代码也就很容易写了:

int ask(int x, int y)
{
    int sum = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
            sum += f[i][j];
    return sum;
}
int ask(int x1, int y1, int x2, int y2)
{
    return ask(x1, y1) - ask(x1, y2 - 1) - ask(x2 - 1, y1) + ask(x2, y2);
}

时间复杂度为\(O(\log n\log m)\)

二维树状数组进阶

上帝造题的七分钟 该题要我们实现二维树状数组的子矩阵修改和子矩阵查询

子矩阵修改

和一维类似的,这里也要用到差分的思想,使用二维差分,我们定义一个二维差分数组 \(d[i][j]\),它与原矩阵元素 \(a[i][j]\) 的关系如下:

\[ d[i][j]=a[i][j]-a[i][j-1]-a[i-1][j]+a[i-1][j-1]\\ \]
\[ a[x][y]=\sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j] \]

现在,假设我们要在顶点为 \((a,b)、(c,d)\) 的矩阵内的每个元素加上 \(k\) ,如下图

Image title

因为我们维护的是差分数组,所以我分别修改四个顶点即可完成对区间的修改,实现如下:

void update(int x, int y, int d)
{
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            f[i][j] += d;
}
void update(int x1, int y1, int x2, int y2, int k)
{
    update(x1, y1, k), update(x1, y2 + 1, -k),
        update(x2 + 1, y1, -k), update(x2 + 1, y2 + 1, k);
}

子矩阵查询

与一维数组类似的,我们求一个子矩阵的和有下面的式子:

\[ \sum_{i=x_2}^{x_1}\sum_{j=y_2}^{y_1}a[i][j]=\sum_{i=1}^{x_1}\sum_{j=1}^{y_1}a[i][j]-\sum_{i=1}^{x_2-1}\sum_{j=1}^{y_1}a[i][j]-\sum_{i=1}^{x_1}\sum_{j=1}^{y_2-1}a[i][j]+\sum_{i=1}^{x_2}\sum_{j=1}^{y_2}a[i][j] \]

问题就转换为计算 \(\sum_{i=1}^{n}\sum_{j=1}^{m}a[i][j]\) ,与推一维数组区间查询类似,我们利用原数组的与差分数组之间的关系进行变换,下面是推导过程:

\[ \begin{align*} &\sum_{i=1}^{n}\sum_{j=1}^{m}a[i][j]\\ = &\sum_{i=1}^{n}\sum_{j=1}^{m}\sum_{k=1}^{i}\sum_{l=1}^{j}d[k][l]\\ =&\sum_{i=1}^{n}\sum_{j=1}^{m}d[i][j]\times (n-i+1)\times (m-j+1)\\ =&(n+1)(m+1)\sum_{i=1}^{n}\sum_{j=1}^{m}d[i][j]-(m+1)\sum_{i=1}^{n}\sum_{j=1}^{m}d[i][j]\times i\\ &(n+1)\sum_{i=1}^{n}\sum_{j=1}^{m}d[i][j]\times j+\sum_{i=1}^{n}\sum_{j=1}^{m}d[i][j]\times i\times j \end{align*} \]

我们可以用四个二维树状数组来分别维护上面四个二维前缀和,下面是实现:

int ask(int x, int y)
{
    int sum = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
        {
            sum += (x + 1) * (y + 1) * f[0][i][j] - 
                (y + 1) * f[1][i][j] - (x + 1) * f[2][i][j] + f[3][i][j];
        }
    return sum;
}

int ask(int x1, int y1, int x2, int y2)
{
    return ask(x2, y2) - ask(x1 - 1, y2) - ask(x2, y1 - 1) + ask(x1 - 1, y1 - 1);
}
2_BIT
#include <bits/stdc++.h>
using namespace std;
const int MAX = 1e3 + 100;

inline int lowbit(int x) { return x & -x; }

// n代表二维数组的行,m代表二维树状数组的列
// f[]数组代表维护的数组
int n, m;
int f[MAX][MAX];

// 基础二维树状数组
// 支持单点修改,子矩阵查询
void update(int x, int y, int d)
{
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            f[i][j] += d;
}

// 二维前缀和查询
int ask(int x, int y)
{
    int sum = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
            sum += f[i][j];
    return sum;
}
// 子矩阵查询
int ask(int x1, int y1, int x2, int y2)
{
    return ask(x1, y1) - ask(x1, y2 - 1) - ask(x2 - 1, y1) + ask(x2, y2);
}

// 4*差分数组+二维树状数组
// 支持子矩阵修改,子矩阵查询
int f2[4][MAX][MAX];
// 子矩阵修改
void update(int x, int y, int d)
{
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            f2[0][i][j] += d, f2[1][i][j] += d * x,
                f2[2][i][j] += y * d, f2[3][i][j] += x * y * d;
}
// 维护子矩阵和
void update(int x1, int y1, int x2, int y2, int k)
{
    update(x1, y1, k), update(x1, y2 + 1, -k),
        update(x2 + 1, y1, -k), update(x2 + 1, y2 + 1, k);
}

// 前缀和查询
int ask(int x, int y)
{
    int sum = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        for (int j = y; j > 0; j -= lowbit(j))
            sum += (x + 1) * (y + 1) * f2[0][i][j] - (y + 1) * f2[1][i][j] - (x + 1) * f2[2][i][j] + f2[3][i][j];
    return sum;
}
// 查询子矩阵和
int ask(int x1, int y1, int x2, int y2)
{
    return ask(x2, y2) - ask(x1 - 1, y2) - ask(x2, y1 - 1) + ask(x1 - 1, y1 - 1);
}