更新数组后处理求和查询

标签: 线段树 数组

难度: Hard

给你两个下标从 0 开始的数组 nums1 和 nums2 ,和一个二维数组 queries 表示一些操作。总共有 3 种类型的操作:

  1. 操作类型 1 为 queries[i] = [1, l, r] 。你需要将 nums1 从下标 l 到下标 r 的所有 0 反转成 1 并且所有 1 反转成 0 。l 和 r 下标都从 0 开始。
  2. 操作类型 2 为 queries[i] = [2, p, 0] 。对于 0 <= i < n 中的所有下标,令 nums2[i] = nums2[i] + nums1[i] * p 。
  3. 操作类型 3 为 queries[i] = [3, 0, 0] 。求 nums2 中所有元素的和。

请你返回一个 数组,包含 所有第三种操作类型 的答案。

示例 1:

输入:nums1 = [1,0,1], nums2 = [0,0,0], queries = [[1,1,1],[2,1,0],[3,0,0]]
输出:[3]
解释:第一个操作后 nums1 变为 [1,1,1] 。第二个操作后,nums2 变成 [1,1,1] ,所以第三个操作的答案为 3 。所以返回 [3] 。

示例 2:

输入:nums1 = [1], nums2 = [5], queries = [[2,0,0],[3,0,0]]
输出:[5]
解释:第一个操作后,nums2 保持不变为 [5] ,所以第二个操作的答案是 5 。所以返回 [5] 。

提示:

  • 1 <= nums1.length,nums2.length <= 105
  • nums1.length = nums2.length
  • 1 <= queries.length <= 105
  • queries[i].length = 3
  • 0 <= l <= r <= nums1.length - 1
  • 0 <= p <= 106
  • 0 <= nums1[i] <= 1
  • 0 <= nums2[i] <= 109

Submission

运行时间: 423 ms

内存: 44.3 MB

class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        n = len(nums1)
        cnt1 = [0] * (4 * n)
        flip = [False] * (4 * n)

        # 维护区间 1 的个数
        def maintain(o: int) -> None:
            cnt1[o] = cnt1[o * 2] + cnt1[o * 2 + 1]

        # 执行区间反转
        def do(o: int, l: int, r: int) -> None:
            cnt1[o] = r - l + 1 - cnt1[o]
            flip[o] = not flip[o]

        # 初始化线段树   o,l,r=1,1,n
        def build(o: int, l: int, r: int) -> None:
            if l == r:
                cnt1[o] = nums1[l - 1]
                return
            m = (l + r) // 2
            build(o * 2, l, m)
            build(o * 2 + 1, m + 1, r)
            maintain(o)

        # 反转区间 [L,R]   o,l,r=1,1,n
        def update(o: int, l: int, r: int, L: int, R: int) -> None:
            if L <= l and r <= R:
                do(o, l, r)
                return
            m = (l + r) // 2
            if flip[o]:
                do(o * 2, l, m)
                do(o * 2 + 1, m + 1, r)
                flip[o] = False
            if m >= L: update(o * 2, l, m, L, R)
            if m < R: update(o * 2 + 1, m + 1, r, L, R)
            maintain(o)

        build(1, 1, n)
        ans, s = [], sum(nums2)
        for op, l, r in queries:
            if op == 1: update(1, 1, n, l + 1, r + 1)
            elif op == 2: s += l * cnt1[1]
            else: ans.append(s)
        return ans
    

Explain

这个题解使用了线段树来处理区间反转和求和的操作。线段树是一种适用于区间操作和区间查询的二叉树结构,能高效地更新和查询数据。通过维护额外的翻转标记(flip)来实现区间翻转操作。对于每个操作,根据操作的类型执行不同的动作:类型1的操作是区间翻转,类型2的操作是根据线段树的根节点统计值更新总和,类型3则是将当前的总和添加到结果列表中。

时间复杂度: O(n + q * log n)

空间复杂度: O(n)

class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        n = len(nums1)
        cnt1 = [0] * (4 * n)  # 线段树节点,存储区间1的个数
        flip = [False] * (4 * n)  # 线段树节点,存储是否需要翻转

        # 维护区间 1 的个数
        def maintain(o: int) -> None:
            cnt1[o] = cnt1[o * 2] + cnt1[o * 2 + 1]

        # 执行区间反转
        def do(o: int, l: int, r: int) -> None:
            cnt1[o] = r - l + 1 - cnt1[o]
            flip[o] = not flip[o]

        # 初始化线段树
        def build(o: int, l: int, r: int) -> None:
            if l == r:
                cnt1[o] = nums1[l - 1]
                return
            m = (l + r) // 2
            build(o * 2, l, m)
            build(o * 2 + 1, m + 1, r)
            maintain(o)

        # 反转区间 [L,R]
        def update(o: int, l: int, r: int, L: int, R: int) -> None:
            if L <= l and r <= R:
                do(o, l, r)
                return
            m = (l + r) // 2
            if flip[o]:
                do(o * 2, l, m)
                do(o * 2 + 1, m + 1, r)
                flip[o] = False
            if m >= L: update(o * 2, l, m, L, R)
            if m < R: update(o * 2 + 1, m + 1, r, L, R)
            maintain(o)

        build(1, 1, n)
        ans, s = [], sum(nums2)
        for op, l, r in queries:
            if op == 1: update(1, 1, n, l + 1, r + 1)
            elif op == 2: s += l * cnt1[1]
            else: ans.append(s)
        return ans

Explore

在初始化线段树时,将`nums1[l - 1]`赋给`cnt1[o]`而不是`nums1[l]`是因为线段树构建函数`build`的参数`l`和`r`是从1开始的,而数组`nums1`的下标是从0开始的。因此,当访问`nums1`的元素时,我们需要使用`nums1[l - 1]`来正确地映射到`nums1`数组的元素。这样做是为了在逻辑上保持线段树节点与数组索引的一致性。

在`do`函数中执行`cnt1[o] = r - l + 1 - cnt1[o]`是为了正确地计算区间内的翻转。这里,`r - l + 1`代表区间内元素的总数。通过从区间总元素数中减去当前记录的`1`的数量`cnt1[o]`,我们可以得到区间内`0`的数量,这恰好是翻转后`1`的数量。因此,该操作能正确地反映区间内0和1的翻转,更新`1`的计数。

为了确保在处理区间反转时不会漏掉或重复执行翻转,特别是在处理部分重叠的区间时,我们采用了延迟传播(懒惰传播)的技术。当一个节点需要被翻转时,我们并不立即更新它的所有子节点,而是在`flip`数组中标记该节点为需要翻转。当实际需要访问这些子节点时(例如进行查询或进一步更新),我们才将翻转操作向下传播到子节点,并且清除当前节点的翻转标记。这种方法确保了每个节点的翻转操作只在必要时执行,从而避免了重复或遗漏。