统计数组中好三元组数目

标签: 树状数组 线段树 数组 二分查找 分治 有序集合 归并排序

难度: Hard

给你两个下标从 0 开始且长度为 n 的整数数组 nums1 和 nums2 ,两者都是 [0, 1, ..., n - 1] 的 排列 。

好三元组 指的是 3 个 互不相同 的值,且它们在数组 nums1 和 nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 v 在 nums1 中出现的位置,pos2v 为值 v 在 nums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1z 和 pos2x < pos2y < pos2z 都成立的 (x, y, z) 。

请你返回好三元组的 总数目 。

示例 1:

输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。

示例 2:

输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。

提示:

  • n == nums1.length == nums2.length
  • 3 <= n <= 105
  • 0 <= nums1[i], nums2[i] <= n - 1
  • nums1 和 nums2 是 [0, 1, ..., n - 1] 的排列。

Submission

运行时间: 448 ms

内存: 32.6 MB

from sortedcontainers import SortedList

class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        p = [0] * n
        for i, x in enumerate(nums1):
            p[x] = i
        ans = 0
        s = SortedList()
        for i in range(1, n - 1):
            s.add(p[nums2[i - 1]])
            y = p[nums2[i]]
            less = s.bisect_left(y)
            ans += less * (n - 1 - y - (i - less))
        return ans

Explain

这个题解采用了一个有序集合来优化查找和插入操作。首先,使用一个数组 p 存储 nums1 中每个值的索引,使得 p[nums2[i]] 可以快速获得 nums2 中元素在 nums1 中的位置。接着,通过遍历 nums2 来构建好的三元组。对于 nums2 中的每一个元素 nums2[i](作为三元组中的 y),使用一个有序集合 s 来维护已经遍历过的 nums2 元素在 nums1 中的位置。对于每个元素 y,通过二分查找其在 s 中的位置,可以得到小于 y 的元素数量(即作为 x 的候选数量)。同时,计算出大于 y 的元素数量(即作为 z 的候选数量),这样就可以确定以 y 为中心的好三元组数量。每次循环结束时,将当前 y 加入集合 s 中,以便后续查找和统计。

时间复杂度: O(n log n)

空间复杂度: O(n)

from sortedcontainers import SortedList

class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        p = [0] * n  # p[x]表示值x在nums1中的位置
        for i, x in enumerate(nums1):
            p[x] = i
        ans = 0
        s = SortedList()  # 用于存储遍历过的nums2[i-1]在nums1中的位置
        for i in range(1, n - 1):
            s.add(p[nums2[i - 1]])  # 将前一个元素的位置添加到有序集合中
            y = p[nums2[i]]  # 当前元素nums2[i]在nums1中的位置
            less = s.bisect_left(y)  # 计算在y之前的元素数量,即小于y的位置数量
            ans += less * (n - 1 - y - (i - less))  # 计算以y为中心的好三元组数量
        return ans

Explore

在算法中,有序集合s通过使用平衡二叉搜索树(如红黑树)或者排序数组来确保元素的顺序。在Python的`sortedcontainers`模块中,`SortedList`通常是基于排序数组实现的,它保持元素在内部数组中排序,从而能够快速进行二分查找和有序插入操作。

选择`SortedList`而不是普通列表或哈希表是因为`SortedList`可以更高效地支持插入和二分查找操作。普通列表虽然可以通过`bisect`模块支持二分查找,但插入操作(尤其是中间插入)的效率较低,因为它需要移动后续的所有元素。哈希表虽然插入和查找操作的平均时间复杂度为O(1),但它不保持元素的顺序,因此无法直接用来获取小于给定值的元素个数。

`less = s.bisect_left(y)`步骤通过二分查找算法在有序集合s中查找元素y应插入的位置,以保持集合的有序性。该方法返回的是y在s中的索引位置,如果s中存在y,这个位置指向第一个y;如果s中不存在y,这个位置是y应当被插入的位置。这个索引同时也表示了s中小于y的元素的数量。

该公式用于计算以元素y为中心的好三元组的数量。这里的`less`是y左侧(小于y的元素)的数量。`n-1-y`是在nums1中y右侧的元素数量。由于我们已经遍历了i个元素,我们需要从这个右侧元素总数中减去已经包括在遍历中的元素数量`i-less`(即y右侧但在遍历中的元素数量),从而得到`n-1-y-(i-less)`。这表示对于每一个小于y的元素(共less个),有`n-1-y-(i-less)`种选择作为z。因此,各个y可以形成的好三元组数量就是`less`乘以这个值。