子数组不同元素数目的平方和 II

标签: 树状数组 线段树 数组 动态规划

难度: Hard

给你一个下标从 0 开始的整数数组 nums 。

定义 nums 一个子数组的 不同计数 值如下:

  • 令 nums[i..j] 表示 nums 中所有下标在 ij 范围内的元素构成的子数组(满足 0 <= i <= j < nums.length ),那么我们称子数组 nums[i..j] 中不同值的数目为 nums[i..j] 的不同计数。

请你返回 nums 中所有子数组的 不同计数 的 平方 和。

由于答案可能会很大,请你将它对 109 + 7 取余 后返回。

子数组指的是一个数组里面一段连续 非空 的元素序列。

示例 1:

输入:nums = [1,2,1]
输出:15
解释:六个子数组分别为:
[1]: 1 个互不相同的元素。
[2]: 1 个互不相同的元素。
[1]: 1 个互不相同的元素。
[1,2]: 2 个互不相同的元素。
[2,1]: 2 个互不相同的元素。
[1,2,1]: 2 个互不相同的元素。
所有不同计数的平方和为 12 + 12 + 12 + 22 + 22 + 22 = 15 。

示例 2:

输入:nums = [2,2]
输出:3
解释:三个子数组分别为:
[2]: 1 个互不相同的元素。
[2]: 1 个互不相同的元素。
[2,2]: 1 个互不相同的元素。
所有不同计数的平方和为 12 + 12 + 12 = 3 。

提示:

  • 1 <= nums.length <= 105
  • 1 <= nums[i] <= 105

Submission

运行时间: 3006 ms

内存: 87.3 MB

class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)
        arr = [0] * (4 * n)
        lazy = [0] * (4 * n)

        def do(o,l,r,op):
            arr[o] += op * (r-l+1)
            lazy[o] += op

        def query(o, l, r, L, R) -> int:
            if L <= l and r <= R:
                ans = arr[o]
                do(o, l, r, 1)
                return ans

            mid = (l + r) // 2

            if lazy[o] != 0:
                do(o * 2, l, mid,lazy[o])
                do(o * 2 + 1, mid + 1, r,lazy[o])
                lazy[o] = 0
            
            ans = 0
            if mid >= L: ans += query(o * 2, l, mid, L, R)
            if mid < R: ans += query(o * 2 + 1, mid + 1, r, L, R)
            arr[o] = arr[o * 2] + arr[o * 2 + 1]
            return ans

        increase = 0
        res = 0
        d = {}
        for i,x in enumerate(nums,1):
            pos = d.get(x, 0)
            # update(1, 1, n, pos+1, i)
            increase += 2 * query(1, 1, n, pos+1, i) + i - pos
            res += increase
            d[x] = i
        return res % (10**9+7)


Explain

该题解采用了线段树结合延迟更新的技巧来处理子数组不同元素数目的计算。首先,线段树用于高效地管理和更新子数组的不同元素数目。延迟更新机制用于在必要时批量更新子数组信息,以提高效率。算法的核心在于维护每个元素最后一次出现的位置,并使用线段树在每次遍历新元素时,更新该元素之后的所有子数组的不同元素计数。这样,每遍历一个新元素,就可以计算出以该元素结尾的所有子数组的不同元素计数增量,并累加到结果中。

时间复杂度: O(n log n)

空间复杂度: O(n)

class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)
        arr = [0] * (4 * n)  # 线段树数组
        lazy = [0] * (4 * n)  # 延迟更新数组

        def do(o, l, r, op):
            # 延迟更新操作
            arr[o] += op * (r - l + 1)
            lazy[o] += op

        def query(o, l, r, L, R) -> int:
            # 线段树区间查询
            if L <= l and r <= R:
                ans = arr[o]
                do(o, l, r, 1)
                return ans

            mid = (l + r) // 2

            if lazy[o] != 0:
                do(o * 2, l, mid, lazy[o])
                do(o * 2 + 1, mid + 1, r, lazy[o])
                lazy[o] = 0

            ans = 0
            if mid >= L: ans += query(o * 2, l, mid, L, R)
            if mid < R: ans += query(o * 2 + 1, mid + 1, r, L, R)
            arr[o] = arr[o * 2] + arr[o * 2 + 1]
            return ans

        increase = 0
        res = 0
        d = {}
        for i, x in enumerate(nums, 1):
            pos = d.get(x, 0)
            increase += 2 * query(1, 1, n, pos + 1, i) + i - pos
            res += increase
            d[x] = i
        return res % (10**9 + 7)

Explore

线段树作为一种二叉树结构,通常存储在一个数组中。为了确保线段树可以覆盖所有区间的情况,选择其大小为原数组长度的四倍是常见的实践。这是因为在最坏的情况下,线段树可能需要更多的空间来存储所有分割的区间,尤其是当区间被细分到单个元素时。这个大小不仅可以容纳所有区间节点,还包括了必要的父节点和子节点,确保树的完整性和功能的实现。虽然这看起来有时可能过大,但它提供了足够的空间保证,避免了动态调整大小的复杂性,通常线段树的空间复杂度为O(4n)是可以接受的,以换取时间效率。

延迟更新技术(lazy propagation)是优化线段树操作的一种方法,特别适用于大量的区间更新操作。在此题解中,当需要更新一个区间内所有元素的值时,而不立即更新每个单独的元素,延迟更新技术通过标记未完成的更新操作来延迟实际的更新过程。这样,在进行区间查询或需要的时候,才会向下传递这些更新,从而减少不必要的计算量。具体实现中,每个节点会有一个额外的延迟标记数组`lazy`,用于记录该节点下所有子节点需要增加的值。在查询或更新操作中,如果遇到有延迟标记的节点,会先处理这个延迟标记,将其应用到子节点上,并清除当前节点的延迟标记,以保证数据的准确性。这种方法大大提高了效率,尤其是在处理大规模数据时。

数组`d`用于记录每个元素最后一次出现的位置,是解决此问题的关键部分之一。尽管这会增加空间复杂度,并在每次元素出现时更新其值,但这种更新操作的时间复杂度为O(1),通常不会显著影响算法的整体效率。然而,当元素种类非常多时,维护这样一个数组可能会消耗较多的内存。为优化这一点,可以考虑使用哈希表代替数组来存储元素的最后位置,特别是当元素范围很大或不连续时。哈希表在这种情况下可以节省空间,且仍然保持常数时间的访问和更新效率。