匹配模式数组的子数组数目 II

标签: 数组 字符串匹配 哈希函数 滚动哈希

难度: Hard

给你一个下标从 0 开始长度为 n 的整数数组 nums ,和一个下标从 0 开始长度为 m 的整数数组 pattern ,pattern 数组只包含整数 -1 ,0 和 1 。

大小为 m + 1 的子数组 nums[i..j] 如果对于每个元素 pattern[k] 都满足以下条件,那么我们说这个子数组匹配模式数组 pattern :

  • 如果 pattern[k] == 1 ,那么 nums[i + k + 1] > nums[i + k]
  • 如果 pattern[k] == 0 ,那么 nums[i + k + 1] == nums[i + k]
  • 如果 pattern[k] == -1 ,那么 nums[i + k + 1] < nums[i + k]

请你返回匹配 pattern 的 nums 子数组的 数目 。

示例 1:

输入:nums = [1,2,3,4,5,6], pattern = [1,1]
输出:4
解释:模式 [1,1] 说明我们要找的子数组是长度为 3 且严格上升的。在数组 nums 中,子数组 [1,2,3] ,[2,3,4] ,[3,4,5] 和 [4,5,6] 都匹配这个模式。
所以 nums 中总共有 4 个子数组匹配这个模式。

示例 2:

输入:nums = [1,4,4,1,3,5,5,3], pattern = [1,0,-1]
输出:2
解释:这里,模式数组 [1,0,-1] 说明我们需要找的子数组中,第一个元素小于第二个元素,第二个元素等于第三个元素,第三个元素大于第四个元素。在 nums 中,子数组 [1,4,4,1] 和 [3,5,5,3] 都匹配这个模式。
所以 nums 中总共有 2 个子数组匹配这个模式。

提示:

  • 2 <= n == nums.length <= 106
  • 1 <= nums[i] <= 109
  • 1 <= m == pattern.length < n
  • -1 <= pattern[i] <= 1

Submission

运行时间: 199 ms

内存: 57.5 MB

class Solution:
    def countMatchingSubarrays(self, nums: List[int], pattern: List[int]) -> int:
        s = []
        for a, b in pairwise(nums):
            if b > a:
                s.append('c')
            elif b == a:
                s.append('b')
            else:
                s.append('a')
        
        s = ''.join(s)
        t = ''.join(['abc'[p + 1] for p in pattern])
        
        def prefix_function(s):
            n = len(s)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j > 0 and s[i] != s[j]:
                    j = pi[j - 1]
                if s[i] == s[j]:
                    j += 1
                pi[i] = j
            return pi
        
        def find_occurrences(t, s):
            cur = s + '#' + t
            sz1, sz2 = len(t), len(s)
            ret = []
            lps = prefix_function(cur)
            for i in range(sz2 + 1, sz1 + sz2 + 1):
                if lps[i] == sz2:
                    ret.append(i - 2 * sz2)
            return ret

        res = len(find_occurrences(s, t))
    
        return res 
        

Explain

本题解首先将`nums`数组转换为一个字符序列`s`,其中每个字符代表相邻元素之间的关系:'a' 表示递减,'b' 表示相等,'c' 表示递增。接着,根据`pattern`数组生成一个目标模式字符串`t`,其中每个字符是基于`pattern`元素映射得到(-1映射为'a', 0映射为'b', 1映射为'c')。随后,代码利用字符串匹配的KMP算法中的前缀函数来查找字符串`t`在`s`中的所有出现位置,每个匹配的位置都代表一个符合条件的子数组。最后,返回这些匹配的数量。

时间复杂度: O(n + m)

空间复杂度: O(n + m)

class Solution:
    def countMatchingSubarrays(self, nums: List[int], pattern: List[int]) -> int:
        s = []
        # 构建表示元素关系的字符串
        for a, b in pairwise(nums):
            if b > a:
                s.append('c')  # 递增
            elif b == a:
                s.append('b')  # 相等
            else:
                s.append('a')  # 递减
        
        s = ''.join(s)
        # 根据pattern构建目标模式字符串
        t = ''.join(['abc'[p + 1] for p in pattern])
        
        # KMP算法的前缀函数
        def prefix_function(s):
            n = len(s)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j > 0 and s[i] != s[j]:
                    j = pi[j - 1]
                if s[i] == s[j]:
                    j += 1
                pi[i] = j
            return pi
        
        # 使用前缀函数找出所有匹配的位置
        def find_occurrences(t, s):
            cur = s + '#' + t
            sz1, sz2 = len(t), len(s)
            ret = []
            lps = prefix_function(cur)
            for i in range(sz2 + 1, sz1 + sz2 + 1):
                if lps[i] == sz2:
                    ret.append(i - 2 * sz2)
            return ret

        # 计算匹配的数量
        res = len(find_occurrences(s, t))
    
        return res

Explore

如果nums数组长度小于2,则无法形成任何相邻元素的比较关系。因此,在这种情况下,字符串s将为空字符串。由于没有可比较的元素关系,直接返回匹配模式的数量为0会是合理的处理方式。

KMP算法被选择用于字符串匹配主要是因为它提供了线性时间复杂度(O(n+m),其中n是文本长度,m是模式长度),且不需要额外的空间复杂度(除了存储前缀函数所需的空间)。这使得KMP算法非常适合在字符串搜索中实现高效和稳定的性能。相比之下,Rabin-Karp算法虽然在平均情况下也有较好的性能,但在最坏情况下会退化到O(n*m);而Boyer-Moore算法虽然在最好情况下非常高效,但其实现较为复杂,且在最坏情况下性能也可能不稳定。

在题目中,pattern数组应只包含-1,0和1这三种值,分别对应于递减、相等和递增的关系。如果pattern数组包含除这三个值之外的其他整数,这将是一个异常情况,因为没有定义如何将这些值转换为'a'、'b'或'c'。在实际实现中,应当对输入进行验证,确保不包含无效的整数。如果遇到无效的整数,可以抛出异常或返回错误信息,提示输入数据不符合预期。