所有子数组中不平衡数字之和

标签: 数组 哈希表 有序集合

难度: Hard

一个长度为 n 下标从 0 开始的整数数组 arr 的 不平衡数字 定义为,在 sarr = sorted(arr) 数组中,满足以下条件的下标数目:

  • 0 <= i < n - 1 ,和
  • sarr[i+1] - sarr[i] > 1

这里,sorted(arr) 表示将数组 arr 排序后得到的数组。

给你一个下标从 0 开始的整数数组 nums ,请你返回它所有 子数组 的 不平衡数字 之和。

子数组指的是一个数组中连续一段 非空 的元素序列。

示例 1:

输入:nums = [2,3,1,4]
输出:3
解释:总共有 3 个子数组有非 0 不平衡数字:
- 子数组 [3, 1] ,不平衡数字为 1 。
- 子数组 [3, 1, 4] ,不平衡数字为 1 。
- 子数组 [1, 4] ,不平衡数字为 1 。
其他所有子数组的不平衡数字都是 0 ,所以所有子数组的不平衡数字之和为 3 。

示例 2:

输入:nums = [1,3,3,3,5]
输出:8
解释:总共有 7 个子数组有非 0 不平衡数字:
- 子数组 [1, 3] ,不平衡数字为 1 。
- 子数组 [1, 3, 3] ,不平衡数字为 1 。
- 子数组 [1, 3, 3, 3] ,不平衡数字为 1 。
- 子数组 [1, 3, 3, 3, 5] ,不平衡数字为 2 。
- 子数组 [3, 3, 3, 5] ,不平衡数字为 1 。
- 子数组 [3, 3, 5] ,不平衡数字为 1 。
- 子数组 [3, 5] ,不平衡数字为 1 。
其他所有子数组的不平衡数字都是 0 ,所以所有子数组的不平衡数字之和为 8 。

提示:

  • 1 <= nums.length <= 1000
  • 1 <= nums[i] <= nums.length

Submission

运行时间: 35 ms

内存: 16.1 MB

class Solution:
    def sumImbalanceNumbers(self, nums: List[int]) -> int:
        n = len(nums)
        right = [0] * n  # nums[i] 右侧的 x 和 x-1 的最近下标(不存在时为 n)
        idx = [n] * (n + 1)
        for i in range(n - 1, -1, -1):
            x = nums[i]
            right[i] = min(idx[x], idx[x - 1])
            idx[x] = i

        ans = 0
        idx = [-1] * (n + 1)
        for i, (x, r) in enumerate(zip(nums, right)):
            # 统计 x 能产生多少贡献
            ans += (i - idx[x - 1]) * (r - i)  # 子数组左端点个数 * 子数组右端点个数
            idx[x] = i
        # 上面计算的时候,每个子数组的最小值必然可以作为贡献,而这是不合法的
        # 所以每个子数组都多算了 1 个不合法的贡献
        return ans - n * (n + 1) // 2

Explain

题解通过两个主要步骤来解决问题:第一步是使用两个数组 right 和 idx 来存储关于每个元素右侧最近的 x 和 x-1 的信息。具体地,right 数组记录对于每个元素 nums[i],在它右侧最近的 x 或 x-1 的索引,如果不存在,则记录为 n。第二步是通过扫描每个元素并计算每个元素对于子数组不平衡数字的贡献。计算方式是,对于每个元素 x,在数组 nums 中位置为 i,它对于所有以 x 为最小值的子数组的贡献都是由两部分组成:从该位置左侧最近的 x-1 的位置到 i 的所有可能的子数组的数量与从 i 到它右侧最近的 x 或 x-1 的位置的所有可能的子数组的数量的乘积。由于在计算过程中,每个子数组的最小值被错误地计入了贡献,因此最后需要从总贡献中减去这部分不合法的贡献。

时间复杂度: O(n)

空间复杂度: O(n)

# Python 3 code with comments

class Solution:
    def sumImbalanceNumbers(self, nums: List[int]) -> int:
        n = len(nums)
        right = [0] * n  # nums[i] 右侧的 x 和 x-1 的最近下标(不存在时为 n)
        idx = [n] * (n + 1)
        for i in range(n - 1, -1, -1):
            x = nums[i]
            right[i] = min(idx[x], idx[x - 1])
            idx[x] = i

        ans = 0
        idx = [-1] * (n + 1)
        for i, (x, r) in enumerate(zip(nums, right)):
            # 统计 x 能产生多少贡献
            ans += (i - idx[x - 1]) * (r - i)  # 子数组左端点个数 * 子数组右端点个数
            idx[x] = i
        # 上面计算的时候,每个子数组的最小值必然可以作为贡献,而这是不合法的
        # 所以每个子数组都多算了 1 个不合法的贡献
        return ans - n * (n + 1) // 2

Explore

在这种问题中,使用数组`right`和`idx`的方法主要是为了高效地访问和更新元素的索引信息。数组可以提供O(1)的时间复杂度进行索引访问和更新,这对于算法的整体性能至关重要。虽然可以使用其他数据结构如哈希表来实现类似的功能,哈希表提供平均O(1)时间复杂度的访问和更新,但是相对于直接使用数组,哈希表在处理碰撞时可能会有额外的性能开销,并且在空间使用上通常比数组要高。因此,在这个特定情况下,直接使用数组是一个既简单又高效的选择。

在题解中,`idx`数组的每个位置用来存储某个数值最后一次出现的索引。将`idx`数组初始化为`n`(`nums`的长度)是为了处理那些在数组`nums`中未出现或者在某个位置右侧不再出现的情况。当`idx[x]`为`n`时,表示数值`x`在当前位置右侧不存在,这是一种边界标记方法,使得算法能正确处理所有元素的右侧边界条件。这种初始化方法简化了边界条件的处理,避免了额外的条件判断,从而提高了代码的整洁性和执行效率。

这种计算方式是基于子数组结构的性质。对于每个元素`x`于位置`i`,要找到所有以`x`为最小值的子数组。子数组可以从`i`的左侧任何一个比`x`大的位置开始,直到遇到一个比`x`小的数(即`x-1`),同样子数组可以向右延伸到任何一个位置,直到`x`或`x-1`。因此,对于每个位置`i`,可以通过`i`左侧到最近的`x-1`的距离(子数组可能的左端点数)乘以从`i`到右侧最近的`x`或`x-1`的距离(子数组可能的右端点数)来计算以`i`为起点的所有子数组中,`x`作为最小值的情况数。这种计算方法能够准确并有效地统计出所有符合条件的子数组,因此被采用在算法中。