找到所有好字符串

标签: 字符串 动态规划 字符串匹配

难度: Hard

给你两个长度为 n 的字符串 s1 和 s2 ,以及一个字符串 evil 。请你返回 好字符串 的数目。

好字符串 的定义为:它的长度为 n ,字典序大于等于 s1 ,字典序小于等于 s2 ,且不包含 evil 为子字符串。

由于答案可能很大,请你返回答案对 10^9 + 7 取余的结果。

示例 1:

输入:n = 2, s1 = "aa", s2 = "da", evil = "b"
输出:51 
解释:总共有 25 个以 'a' 开头的好字符串:"aa","ac","ad",...,"az"。还有 25 个以 'c' 开头的好字符串:"ca","cc","cd",...,"cz"。最后,还有一个以 'd' 开头的好字符串:"da"。

示例 2:

输入:n = 8, s1 = "leetcode", s2 = "leetgoes", evil = "leet"
输出:0 
解释:所有字典序大于等于 s1 且小于等于 s2 的字符串都以 evil 字符串 "leet" 开头。所以没有好字符串。

示例 3:

输入:n = 2, s1 = "gx", s2 = "gz", evil = "x"
输出:2

提示:

  • s1.length == n
  • s2.length == n
  • s1 <= s2
  • 1 <= n <= 500
  • 1 <= evil.length <= 50
  • 所有字符串都只包含小写英文字母。

Submission

运行时间: 206 ms

内存: 21.1 MB

class Solution:
    def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
        def prefix_function(pattern: str) -> []:
            n = len(pattern)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j != 0 and pattern[i] != pattern[j]:
                    j = pi[j - 1]
                if pattern[i] == pattern[j]:
                    j += 1
                pi[i] = j
            return pi

        def match_char(pattern: str) -> []:
            n = len(pattern)
            pi = prefix_function(pattern)
            dp = [[0 for j in range(26)] for i in range(n)]
            dp[0][ord(pattern[0]) - 97] = 1
            for i in range(1, n):
                for cur in range(0, 26):
                    ch = chr(cur + 97)
                    if ch == pattern[i]:
                        dp[i][cur] = i + 1
                    else:
                        j = pi[i - 1]
                        dp[i][cur] = dp[j][cur]
            return dp

        n = len(s2)
        pn = len(evil)
        dp = match_char(evil)

        @lru_cache(None)
        def dfs(i:int, match:int, limit_low:bool, limit_high:bool) -> int:
            if i == n:
                return match < pn
            if match >= pn:
                return 0
            lo = ord(s1[i]) if limit_low else ord('a')
            hi = ord(s2[i]) if limit_high else ord('z')
            res = 0
            for cur in range(lo, hi + 1):
                res += dfs(i+1, dp[match][cur - 97], limit_low and cur == lo, limit_high and cur == hi)
            res %= (10 ** 9 + 7)
            return res
        return dfs(0, 0, True, True)

Explain

这道题的解法使用了动态规划和KMP算法的思想。首先通过KMP算法的前缀函数计算出evil字符串的前缀函数,然后基于前缀函数构建一个dp数组,dp[i][j]表示evil字符串的前i个字符与某个字符串匹配,且当前匹配到字符j时,evil字符串的最大匹配长度。接下来使用记忆化搜索,枚举满足条件的字符串,搜索过程中维护当前枚举到的位置i,当前匹配的evil字符串的长度match,以及是否受到s1和s2的字典序限制。搜索到字符串末尾时,如果match小于evil字符串的长度,说明当前字符串是好字符串,累加答案。

时间复杂度: O(n × m × 26)

空间复杂度: O(n × m + m × 26)

class Solution:
    def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
        def prefix_function(pattern: str) -> []:
            # 计算字符串pattern的前缀函数
            n = len(pattern)
            pi = [0] * n
            for i in range(1, n):
                j = pi[i - 1]
                while j != 0 and pattern[i] != pattern[j]:
                    j = pi[j - 1]
                if pattern[i] == pattern[j]:
                    j += 1
                pi[i] = j
            return pi

        def match_char(pattern: str) -> []:
            # 构建dp数组,dp[i][j]表示pattern的前i个字符与某个字符串匹配,且当前匹配到字符j时,pattern的最大匹配长度
            n = len(pattern)
            pi = prefix_function(pattern)
            dp = [[0 for j in range(26)] for i in range(n)]
            dp[0][ord(pattern[0]) - 97] = 1
            for i in range(1, n):
                for cur in range(0, 26):
                    ch = chr(cur + 97)
                    if ch == pattern[i]:
                        dp[i][cur] = i + 1
                    else:
                        j = pi[i - 1]
                        dp[i][cur] = dp[j][cur]
            return dp

        n = len(s2)
        pn = len(evil)
        dp = match_char(evil)

        @lru_cache(None)
        def dfs(i:int, match:int, limit_low:bool, limit_high:bool) -> int:
            # 记忆化搜索,枚举满足条件的字符串
            if i == n:
                return match < pn
            if match >= pn:
                return 0
            lo = ord(s1[i]) if limit_low else ord('a')
            hi = ord(s2[i]) if limit_high else ord('z')
            res = 0
            for cur in range(lo, hi + 1):
                res += dfs(i+1, dp[match][cur - 97], limit_low and cur == lo, limit_high and cur == hi)
            res %= (10 ** 9 + 7)
            return res
        return dfs(0, 0, True, True)

Explore

在KMP算法中,前缀函数被用来表示字符串的最长相同前后缀的长度。当计算前缀函数时,如果当前字符与j指向的字符不匹配,这意味着当前的匹配失败。因此,我们需要找到一个更短的有效匹配,这可以通过跳转到pi[j - 1]来实现,即跳到当前已匹配前缀的下一个最长相同前后缀的末尾。这样做可以避免重复检查已知不匹配的字符,从而提高匹配效率。

dp数组的每个元素dp[i][j]表示在字符串pattern中,以i结尾的子串能够匹配的最大长度,当下一个字符为j(ASCII码减去97得到字符'a'到'z'的索引)时的情况。对于dp[0][ord(pattern[0]) - 97] = 1是因为第一个字符匹配时,匹配长度为1。对于其他值,当字符不匹配时,我们需要回溯到前一个匹配的状态,这由前缀函数pi给出。因此,对于dp数组中的其他元素,我们根据前缀函数计算得到的最长可匹配后缀的继续位置来填充,这保证了在非直接匹配的情况下,能找到最长的可能的匹配状态。

在dfs函数中,参数limit_low和limit_high用于控制生成字符串的字典序边界以确保其在s1和s2的范围内。当limit_low为true时,表示当前位置的字符不能小于s1在同一位置的字符。类似地,当limit_high为true时,表示当前位置的字符不能大于s2在同一位置的字符。这两个参数在递归过程中动态更新:如果当前字符恰好等于s1(或s2)的对应字符,相应的限制参数保持为true,否则设为false。这种机制确保了生成的字符串在给定的字典序范围内递归搜索。

在dfs函数中,如果match等于或超过evil字符串的长度,这意味着在生成的字符串中已经完全包含了evil字符串作为子串。题目要求找到不包含evil字符串的好字符串。因此,一旦发现evil字符串已经被完全匹配,当前分支的搜索可以终止,直接返回0,表示这条路径不产生任何有效的好字符串,这是为了满足问题的约束条件而进行的优化处理。