统计异或值在范围内的数对有多少

标签: 位运算 字典树 数组

难度: Hard

给你一个整数数组 nums (下标 从 0 开始 计数)以及两个整数:lowhigh ,请返回 漂亮数对 的数目。

漂亮数对 是一个形如 (i, j) 的数对,其中 0 <= i < j < nums.lengthlow <= (nums[i] XOR nums[j]) <= high

 

示例 1:

输入:nums = [1,4,2,7], low = 2, high = 6
输出:6
解释:所有漂亮数对 (i, j) 列出如下:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

示例 2:

输入:nums = [9,8,4,2,1], low = 5, high = 14
输出:8
解释:所有漂亮数对 (i, j) 列出如下:
​​​​​    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5

 

提示:

  • 1 <= nums.length <= 2 * 104
  • 1 <= nums[i] <= 2 * 104
  • 1 <= low <= high <= 2 * 104

Submission

运行时间: 252 ms

内存: 18.7 MB

class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        ans, cnt = 0, Counter(nums)
        high += 1
        while high:
            nxt = Counter()
            for x, c in cnt.items():
                if high & 1: ans += c * cnt[x ^ (high - 1)]
                if low & 1:  ans -= c * cnt[x ^ (low - 1)]
                nxt[x >> 1] += c
            cnt = nxt
            low >>= 1
            high >>= 1
        return ans // 2

Explain

该题解采用了基于计数和二分字典树(Trie)的思想,但实际上未使用树结构。算法通过逐位考察数字,并使用计数器来跟踪相同值的出现次数。对每一位的处理,算法检查当前位的0或1是否会让XOR结果在low和high范围内,并相应地更新计数器。对每个数字,都尝试与可能的XOR匹配进行配对,并调整答案计数。这种方法在处理每一位时都减少了不必要的计算,从而提高效率。

时间复杂度: O(n log C)

空间复杂度: O(n)

# 添加了注释的题解代码

class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        ans, cnt = 0, Counter(nums)  # 初始化答案为0和计数器
        high += 1  # 将high增加1简化比较
        while high:  # 当high不为0时循环
            nxt = Counter()  # 下一轮的计数器
            for x, c in cnt.items():  # 遍历当前计数器的每个数字及其出现次数
                if high & 1: ans += c * cnt[x ^ (high - 1)]  # 如果当前high的最低位为1,增加计数
                if low & 1:  ans -= c * cnt[x ^ (low - 1)]  # 如果当前low的最低位为1,减少计数
                nxt[x >> 1] += c  # 填充下一轮计数器
            cnt = nxt  # 更新计数器为下一轮的计数器
            low >>= 1  # 将low右移一位
            high >>= 1  # 将high右移一位
        return ans // 2  # 返回最终的答案,除以2是因为每对被计算了两次

Explore

在算法中将`high`增加1是为了将原本的`[low, high]`范围转换成`[low, high)`(半开区间),这使得在使用位运算处理时更加方便。具体来说,当我们检查二进制的每一位时,增加1后的`high`可以直接用来判断是否达到边界条件,而不需要额外的操作来处理边界情况。这种处理方法可以简化逻辑判断,使代码更易于理解和维护。

题解中的计数器用于跟踪每一位二进制上各数字的出现次数。在每一位的处理过程中,我们不是重新计算所有数字的计数,而是更新计数器来反映右移操作后的新值。这种方法通过避免重新计算已处理位的数字,大大提高了效率,因为它只处理当前需要考虑的位。

在算法中,`high & 1`和`low & 1`用于检查`high`和`low`的当前最低位是否为1。这种检查是因为我们需要确定当前位是否可以通过异或操作达到期望的0或1,从而使得结果仍然在[low, high)区间内。如果`high`的最低位为1,表示高边界在这一位上可以达到1,因此需要加上符合条件的数对计数。同理,如果`low`的最低位为1,表示低边界在这一位上要求至少为1,因此需要减去不符合条件的数对计数。这种基于位的处理方法使得算法在统计符合范围的数对时更加精确和高效。

在题解中,`nxt[x >> 1] += c`操作是将当前数字`x`右移一位后,将其计数加入到新的计数器`nxt`中。这一步骤的作用是为了在处理下一位之前,预先处理掉当前位的信息,从而只关注剩下的更高位。通过这种方式,算法可以在每一步只关注当前需要处理的位,而不用重新计算所有位,从而减少了不必要的计算量,并使得算法更加高效。