由单个字符重复的最长子字符串

标签: 线段树 数组 字符串 有序集合

难度: Hard

给你一个下标从 0 开始的字符串 s 。另给你一个下标从 0 开始、长度为 k 的字符串 queryCharacters ,一个下标从 0 开始、长度也是 k 的整数 下标 数组 queryIndices ,这两个都用来描述 k 个查询。

i 个查询会将 s 中位于下标 queryIndices[i] 的字符更新为 queryCharacters[i]

返回一个长度为 k 的数组 lengths ,其中 lengths[i] 是在执行第 i 个查询 之后 s 中仅由 单个字符重复 组成的 最长子字符串长度

示例 1:

输入:s = "babacc", queryCharacters = "bcb", queryIndices = [1,3,3]
输出:[3,3,4]
解释:
- 第 1 次查询更新后 s = "bbbacc" 。由单个字符重复组成的最长子字符串是 "bbb" ,长度为 3 。
- 第 2 次查询更新后 s = "bbbccc" 。由单个字符重复组成的最长子字符串是 "bbb" 或 "ccc",长度为 3 。
- 第 3 次查询更新后 s = "bbbbcc" 。由单个字符重复组成的最长子字符串是 "bbbb" ,长度为 4 。
因此,返回 [3,3,4] 。

示例 2:

输入:s = "abyzz", queryCharacters = "aa", queryIndices = [2,1]
输出:[2,3]
解释:
- 第 1 次查询更新后 s = "abazz" 。由单个字符重复组成的最长子字符串是 "zz" ,长度为 2 。
- 第 2 次查询更新后 s = "aaazz" 。由单个字符重复组成的最长子字符串是 "aaa" ,长度为 3 。
因此,返回 [2,3] 。

提示:

  • 1 <= s.length <= 105
  • s 由小写英文字母组成
  • k == queryCharacters.length == queryIndices.length
  • 1 <= k <= 105
  • queryCharacters 由小写英文字母组成
  • 0 <= queryIndices[i] < s.length

Submission

运行时间: 3867 ms

内存: 36.7 MB

from itertools import groupby
from typing import List
from sortedcontainers import SortedList, SortedDict


class Solution:
    def longestRepeating(self, s: str, queryCharacters: str, queryIndices: List[int]) -> List[int]:
        def split(index: int):
            if index < 0  or index >= n:
                return
            #找到最后一个大于等于index的元素的位置
            curPos = sMap.bisect_right(index) - 1
            s1, e1 = sMap.peekitem(curPos)
            #如果index是某一个区间的起点 那么直接退出
            if s1 == index:
                return
            sMap.popitem(curPos)
            sMap[s1] = index - 1
            sMap[index] = e1
            sList.remove(e1 - s1 + 1)
            sList.add(index - 1 - s1 + 1)
            sList.add(e1 - index + 1)
        
        def union(index: int) -> None:
            """
            如果以 index 为起点的区间和其前一个区间内的字符相同,合并两个区间

            """
            if index < 0 or index >= n:
                return
            #找到index所在区间
            curPos = sMap.bisect_right(index) - 1
            prePos = curPos - 1
            if prePos < 0:
                return
            (s1, e1), (s2, e2) = sMap.peekitem(prePos), sMap.peekitem(curPos)
            if chars[s1] == chars[s2]:
                sMap.popitem(curPos)
                sMap[s1] = e2
                sList.remove(e2 - s2 + 1)
                sList.remove(e1 - s1 + 1)
                sList.add(e2 - s1 + 1)
        sMap = SortedDict()  # start=>end
        sList = SortedList(key=lambda x: -x)

        # 1.初始化
        start = 0
        for _, group in groupby(s):
            len_ = len(list(group))
            sMap[start] = start + len_ - 1
            sList.add(len_)
            start += len_

        res = [0] * len(queryIndices)
        n, chars = len(s), list(s)

        for i, (qc, qi) in enumerate(zip(queryCharacters, queryIndices)):
            if chars[qi] == qc:
                res[i] = sList[0]
                continue

            chars[qi] = qc

            # 断开qi
            split(qi)
            split(qi + 1)

            # 向左连接
            union(qi)
            union(qi + 1)

            res[i] = sList[0]

        return res

Explain

题解使用了SortedDict和SortedList来维护子字符串的信息。首先,初始化过程中,通过groupby将字符串s划分为由相同字符组成的多个子字符串,并存储这些子字符串的起始和结束位置到SortedDict中,同时将子字符串的长度存入SortedList中。对于每次查询,先检查新字符是否与原字符相同,若相同则直接返回当前最长重复子字符串长度;若不同,则更新字符,并分割和合并相关区间以维护正确的子字符串信息。分割操作确保每次查询只影响一个很小的区间,而合并操作尝试将相邻的由相同字符组成的区间合并起来。通过这种方式,每次查询后都能快速获取最长重复子字符串的长度。

时间复杂度: O(n + k log n)

空间复杂度: O(n)

from itertools import groupby
from typing import List
from sortedcontainers import SortedList, SortedDict

class Solution:
    def longestRepeating(self, s: str, queryCharacters: str, queryIndices: List[int]) -> List[int]:
        def split(index: int):
            if index < 0  or index >= n:
                return
            curPos = sMap.bisect_right(index) - 1
            s1, e1 = sMap.peekitem(curPos)
            if s1 == index:
                return
            sMap.popitem(curPos)
            sMap[s1] = index - 1
            sMap[index] = e1
            sList.remove(e1 - s1 + 1)
            sList.add(index - 1 - s1 + 1)
            sList.add(e1 - index + 1)
        
        def union(index: int) -> None:
            if index < 0 or index >= n:
                return
            curPos = sMap.bisect_right(index) - 1
            prePos = curPos - 1
            if prePos < 0:
                return
            (s1, e1), (s2, e2) = sMap.peekitem(prePos), sMap.peekitem(curPos)
            if chars[s1] == chars[s2]:
                sMap.popitem(curPos)
                sMap[s1] = e2
                sList.remove(e2 - s2 + 1)
                sList.remove(e1 - s1 + 1)
                sList.add(e2 - s1 + 1)
        sMap = SortedDict()  # start=>end
        sList = SortedList(key=lambda x: -x)
        start = 0
        for _, group in groupby(s):
            len_ = len(list(group))
            sMap[start] = start + len_ - 1
            sList.add(len_)
            start += len_
        res = [0] * len(queryIndices)
        n, chars = len(s), list(s)
        for i, (qc, qi) in enumerate(zip(queryCharacters, queryIndices)):
            if chars[qi] == qc:
                res[i] = sList[0]
                continue
            chars[qi] = qc
            split(qi)
            split(qi + 1)
            union(qi)
            union(qi + 1)
            res[i] = sList[0]
        return res

Explore

在Python中,`itertools.groupby`函数可以将连续的相同元素分组。传给`groupby`的参数是字符串`s`,它会根据每个字符的相等性来分组。每次迭代返回一个键和一个迭代器,键是组中的字符,迭代器包含该字符的连续重复序列。因此,这种方式自然地确保了每个通过`groupby`生成的子序列都是由完全相同的字符构成。

在`s`中的特定`index`处执行`split`函数的目的是为了正确地处理字符的更新,即使该字符与原位置字符相同。分割主要是因为更新操作可能影响已有区间的结构。例如,如果在`index`处的字符被更新(即使更新后字符相同),我们可能需要重新划分边界来确保后续的`union`操作可以正确地合并新的相同字符区间。这样的设计确保了数据结构的一致性和查询的准确性。

在`union`函数中,合并两个相邻区间的决策基于两个主要条件:1) 这两个区间是否是相邻的;2) 区间的边界字符是否相同。首先,函数检查当前区间和前一个区间的位置,如果它们是连续的(即一个区间的结束位置与另一个区间的开始位置相邻),然后检查这两个区间的字符是否相同。如果这两个条件都满足,那么这两个区间可以合并成一个更大的区间。这样的合并有助于维护和更新最长连续相同字符子串的长度。

在更新`SortedDict`和`SortedList`时,首先要做的是从`SortedList`中删除旧的区间长度。这是通过直接调用`remove`方法完成的,该方法根据原有的区间长度删除相关的条目。紧接着,当新的区间长度形成(无论是通过分割还是合并操作),新的长度将被添加到`SortedList`中,使用`add`方法。这种方式确保`SortedList`始终保持最新的区间长度信息,从而可以快速返回最大的区间长度。