寻找两个正序数组的中位数

标签: 数组 二分查找 分治

难度: Hard

给定两个大小分别为 mn 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数

算法的时间复杂度应该为 O(log (m+n))

示例 1:

输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2

示例 2:

输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

提示:

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -106 <= nums1[i], nums2[i] <= 106

Submission

运行时间: 48 ms

内存: 15 MB

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        m = len(nums1)
        n = len(nums2)
        if m > n:
            return self.findMedianSortedArrays(nums2, nums1)
        
        total_left = (m + n + 1) // 2
        l = 0
        r = m
        while l < r:
            i = l + (r - l + 1) // 2
            j = total_left - i
            if nums1[i-1] <= nums2[j]:
                l = i
            else:
                r = i - 1
        
        i = l
        j = total_left - i
        left_nums1_max = float('-inf') if i == 0 else nums1[i-1]
        left_nums2_max = float('-inf') if j == 0 else nums2[j-1]
        right_nums1_min = float('inf') if i == m else nums1[i]
        right_nums2_min = float('inf') if j == n else nums2[j]

        if (m + n) % 2 == 1:
            return max(left_nums1_max, left_nums2_max)
        else:
            return (max(left_nums1_max, left_nums2_max) + min(right_nums1_min, right_nums2_min)) / 2

Explain

该题解使用了二分查找的思想。首先确保较短的数组为 nums1,如果不是就交换。然后使用二分查找在 nums1 中找到一个分割点 i,使得 nums1[0...i-1] 和 nums2[0...j-1](j = (m+n+1)//2 - i)的元素个数之和等于 (m+n+1)//2。这样就可以保证左半部分的元素个数等于右半部分或者多一个。最后,根据总元素个数的奇偶性,返回左半部分的最大值或左右部分的最大最小值的平均数作为中位数。

时间复杂度: O(log min(m, n))

空间复杂度: O(1)

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        m = len(nums1)
        n = len(nums2)
        if m > n:
            # 确保 nums1 是较短的数组
            return self.findMedianSortedArrays(nums2, nums1)
        
        total_left = (m + n + 1) // 2
        l = 0
        r = m
        while l < r:
            # 在 nums1 中二分查找分割点 i
            i = l + (r - l + 1) // 2
            j = total_left - i
            if nums1[i-1] <= nums2[j]:
                l = i
            else:
                r = i - 1
        
        i = l
        j = total_left - i
        # 获取左半部分的最大值
        left_nums1_max = float('-inf') if i == 0 else nums1[i-1]
        left_nums2_max = float('-inf') if j == 0 else nums2[j-1]
        # 获取右半部分的最小值
        right_nums1_min = float('inf') if i == m else nums1[i]
        right_nums2_min = float('inf') if j == n else nums2[j]

        if (m + n) % 2 == 1:
            # 如果总长度为奇数,中位数就是左半部分的最大值
            return max(left_nums1_max, left_nums2_max)
        else:
            # 如果总长度为偶数,中位数就是左半部分最大值和右半部分最小值的平均值
            return (max(left_nums1_max, left_nums2_max) + min(right_nums1_min, right_nums2_min)) / 2

Explore

确保较短的数组为 `nums1` 的主要目的是为了优化二分查找的效率和简化边界条件处理。当 `nums1` 较短时,我们在它上进行二分查找的次数会减少,因为其长度 `m` 较小,从而使得二分查找的最大可能次数为 `log(m)`。此外,处理边界情况(如 `i` 达到0或`m`)也会更简单,因为较短的数组中分割点的变动范围小。

使用 `i = l + (r - l + 1) // 2` 是为了确保 `i` 在两个候选区间的更靠右侧,这有助于防止在某些情况下出现死循环或过早收敛。这种写法确保在 `l` 和 `r` 相邻的情况下,选择的是靠右的位置,避免了可能的无限循环。

这里的比较方式 `nums1[i-1] <= nums2[j]` 是为了确保左半部分的最大值不大于右半部分的最小值。因为 `i` 和 `j` 是根据总元素数的一半来计算的,所以 `nums1[i-1]` 是左半部分的最大值,而 `nums2[j]` 是右半部分的最小值。通过确保左半部最大值不大于右半部最小值,我们可以保持元素的正确分割。使用 `nums1[i] <= nums2[j-1]` 这种比较方式则可能会错过正确的分割点,因为它关注的是左半部分的最小值和右半部分的最大值,而不是我们所需要的左半部分的最大值和右半部分的最小值。

总元素个数的奇偶性决定了中位数是单一元素还是两个元素的平均。如果总长度为奇数,则中间的元素直接是中位数;如果总长度为偶数,则中位数是中间两个元素的平均值。通过计算左半部分最大值和右半部分最小值,我们可以根据总元素的奇偶性来决定是直接取这个最大值,还是取这两个值的平均,以此来正确地计算中位数。