相似字符串组

标签: 深度优先搜索 广度优先搜索 并查集 数组 哈希表 字符串

难度: Hard

如果交换字符串 X 中的两个不同位置的字母,使得它和字符串 Y 相等,那么称 XY 两个字符串相似。如果这两个字符串本身是相等的,那它们也是相似的。

例如,"tars""rats" 是相似的 (交换 02 的位置); "rats""arts" 也是相似的,但是 "star" 不与 "tars""rats",或 "arts" 相似。

总之,它们通过相似性形成了两个关联组:{"tars", "rats", "arts"}{"star"}。注意,"tars""arts" 是在同一组中,即使它们并不相似。形式上,对每个组而言,要确定一个单词在组中,只需要这个词和该组中至少一个单词相似。

给你一个字符串列表 strs。列表中的每个字符串都是 strs 中其它所有字符串的一个字母异位词。请问 strs 中有多少个相似字符串组?

示例 1:

输入:strs = ["tars","rats","arts","star"]
输出:2

示例 2:

输入:strs = ["omv","ovm"]
输出:1

提示:

  • 1 <= strs.length <= 300
  • 1 <= strs[i].length <= 300
  • strs[i] 只包含小写字母。
  • strs 中的所有单词都具有相同的长度,且是彼此的字母异位词。

Submission

运行时间: 158 ms

内存: 16.5 MB

class Solution:
    def numSimilarGroups(self, A: List[str]) -> int:
        A = [*{*A}]                             #字符串去重,这个是题目给的坑
        n, m = len(A), len(A[0])
        self.p = [*range(n)]                    #并查集初始化
        self.nmm(A) if n > m * m else self.nnm(A)#选择方案
        return len({*map(self.uni, self.p)})    #并查集去重输出长度
    
    def uni(self, x: int):                      #并查集查询连接函数
        if x != self.p[x]:
            self.p[x] = self.uni(self.p[x])
        return self.p[x]
        
    def nnm(self, A: List[str]):                #O(N^2*M)算法
        n, m = len(A), len(A[0])
        def check(x, y):                        #相似判定函数
            t = 0
            for i in range(m):
                if x[i] != y[i]:
                    t += 1
                    if t > 2:
                        return False
            return True
        for i in range(n):
            for j in range(i + 1, n):           #遍历串的两两组合,然后并查集连接
                pi, pj = self.uni(i), self.uni(j)
                if pi != pj and check(A[i], A[j]):
                    self.p[pj] = pi
        
    def nmm(self, A: List[str]):                #O(N*M^2)算法
        n, m = len(A), len(A[0])
        d = collections.defaultdict(list)       #匹配字典
        e = set()                               #关系集合
        for i, w in enumerate(A):
            for l in range(m):
                for r in range(l + 1, m):       #遍历每个串的两个位置,生成通配串
                    t_w = f'{w[: l]}.{w[l + 1: r]}.{w[r + 1: ]}'
                    if d[t_w]:
                        for j in d[t_w]:        #生成串串关系
                            e |= {(i, j)}
                    d[t_w] += [i]
        for i, j in e:                          #遍历关系集合,然后并查集连接
            pi, pj = self.uni(i), self.uni(j)
            if pi != pi:
                self.p[pj] = pi

# 作者:typingMonkey
# 链接:https://leetcode.cn/problems/similar-string-groups/solutions/22174/839-xiang-si-zi-fu-chuan-zu-pythontai-man-liao-bao/
# 来源:力扣(LeetCode)
# 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

Explain

这个题解使用并查集来解决相似字符串分组的问题。主要思路是: 1. 先对输入的字符串数组去重。 2. 根据输入规模选择两种不同的算法: - 当字符串数量n大于字符串长度平方m^2时,使用O(N*M^2)的算法 - 否则使用O(N^2*M)的算法 3. O(N^2*M)算法通过两两比较字符串的相似性,对相似的字符串进行并查集合并操作 4. O(N*M^2)算法先生成每个字符串的所有通配形式,建立通配串到原字符串的映射,同时记录字符串之间的相似关系,最后对有相似关系的字符串进行并查集合并 5. 最后对并查集的根节点去重并输出集合的大小,即为相似字符串组的数量

时间复杂度: O(min(N^2*M, N*M^2))

空间复杂度: O(N) ~ O(N*M^2)

class Solution:
    def numSimilarGroups(self, A: List[str]) -> int:
        # 字符串去重,这个是题目给的坑
        A = [*{*A}]
        n, m = len(A), len(A[0])
        # 并查集初始化 
        self.p = [*range(n)]
        # 选择方案:当n>m^2时使用O(N*M^2)算法,否则使用O(N^2*M)算法
        self.nmm(A) if n > m * m else self.nnm(A)
        # 对并查集的根节点去重并输出集合大小
        return len({*map(self.uni, self.p)})
    
    # 并查集查询连接函数
    def uni(self, x: int):
        if x != self.p[x]:
            self.p[x] = self.uni(self.p[x])
        return self.p[x]
        
    # O(N^2*M)算法
    def nnm(self, A: List[str]):
        n, m = len(A), len(A[0])
        # 相似判定函数
        def check(x, y):
            t = 0
            for i in range(m):
                if x[i] != y[i]:
                    t += 1
                    if t > 2:
                        return False
            return True
        for i in range(n):
            for j in range(i + 1, n):      
                # 遍历串的两两组合,然后并查集连接
                pi, pj = self.uni(i), self.uni(j)
                if pi != pj and check(A[i], A[j]):
                    self.p[pj] = pi
        
    # O(N*M^2)算法
    def nmm(self, A: List[str]):
        n, m = len(A), len(A[0])
        # 匹配字典
        d = collections.defaultdict(list)
        # 关系集合    
        e = set()
        for i, w in enumerate(A):
            for l in range(m):
                for r in range(l + 1, m):  
                    # 遍历每个串的两个位置,生成通配串
                    t_w = f'{w[: l]}.{w[l + 1: r]}.{w[r + 1: ]}'
                    if d[t_w]:
                        for j in d[t_w]:
                            # 生成串串关系
                            e |= {(i, j)}
                    d[t_w] += [i]
        for i, j in e: 
            # 遍历关系集合,然后并查集连接
            pi, pj = self.uni(i), self.uni(j)
            if pi != pi:
                self.p[pj] = pi

Explore

选择使用O(N^2*M)或O(N*M^2)算法的依据基于输入规模的不同。当字符串数组的长度n大于字符串长度m的平方时,使用O(N*M^2)算法会更有效率,因为这种情况下每个字符串的变化和比较次数相对较少,可以通过生成所有可能的通配形式来快速识别和关联相似字符串。相比之下,如果n小于或等于m^2,O(N^2*M)算法通过直接比较所有字符串对将更有效,因为此时n较小而m可能较大,直接比较成本较低。

并查集的路径压缩是一种优化技术,用于在执行查找操作时减少树的高度,从而提高操作的效率。在查找函数`uni`中,如果当前元素x的父节点不是它自己,即`x != self.p[x]`,则递归调用`uni`函数来找到根节点,并把当前节点的父节点直接设置为根节点。这样,经过几次操作后,树的高度大幅降低,使得后续的查找操作更快。

在题解的O(N*M^2)算法中,`生成每个字符串的所有通配形式`是指生成所有可能的两个字符置换后的字符串形式。具体实现方式为:遍历每个字符串,选择两个不同的位置l和r,交换这两个位置的字符后生成一个新的字符串。例如,字符串`abc`在位置1和3交换后得到`cba`。这些通配形式被用于快速检测和记录不同字符串之间的相似关系,通过一个字典存储通配形式到原始字符串的映射,从而在发现两个相同通配形式的字符串时,可以确认它们是相似的。

相似判定函数`check`的目的是判断两个字符串是否通过最多一次字符交换就可以变得相同。如果在比较过程中发现两个字符串在超过两个位置上的字符不同,那么这意味着无法通过单一交换来让这两个字符串相同,因此直接返回False。这是基于题目中对相似字符串的定义,即两个字符串相似当且仅当它们可以通过交换两个字符变得相同。