统计树中的合法路径数目

标签: 深度优先搜索 数学 动态规划 数论

难度: Hard

给你一棵 n 个节点的无向树,节点编号为 1 到 n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi] 表示节点 ui 和 vi 在树中有一条边。

请你返回树中的 合法路径数目 。

如果在节点 a 到节点 b 之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b) 是 合法的 。

注意:

  • 路径 (a, b) 指的是一条从节点 a 开始到节点 b 结束的一个节点序列,序列中的节点 互不相同 ,且相邻节点之间在树上有一条边。
  • 路径 (a, b) 和路径 (b, a) 视为 同一条 路径,且只计入答案 一次 。

示例 1:

输入:n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]
输出:4
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
只有 4 条合法路径。

示例 2:

输入:n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]
输出:6
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (1, 6) 因为路径 1 到 6 只包含一个质数 3 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
- (3, 6) 因为路径 3 到 6 只包含一个质数 3 。
只有 6 条合法路径。

提示:

  • 1 <= n <= 105
  • edges.length == n - 1
  • edges[i].length == 2
  • 1 <= ui, vi <= n
  • 输入保证 edges 形成一棵合法的树。

Submission

运行时间: 234 ms

内存: 49.3 MB

N = 100001
PRIME = [True] * N
PRIME[1] = False
for i in range(2, N):
    if PRIME[i]:
        for j in range(i * i, N, i):
            PRIME[j] = False
            

class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        UF, rank = list(range(n + 1)), [1] * (n + 1)

        def find(u: int) -> int:
            if UF[u] != u:
                UF[u] = find(UF[u])
            return UF[u]

        def union(u: int, v: int):
            u, v = find(u), find(v)
            if rank[v] < rank[u]:
                u, v = v, u
            UF[u], rank[v] = v, rank[u] + rank[v]

        G = defaultdict(list)
        for u, v in edges:
            if PRIME[u] != PRIME[v]:
                if PRIME[v]:
                    u, v = v, u
                G[u].append(v)
            elif not PRIME[u]:
                union(u, v)

        fn = lambda np: sum(np) + sum(a * b for a, b in zip(np[1:], accumulate(np)))
        return sum(fn([rank[find(k)] for k in cn]) for cn in G.values())

Explain

此题解采用了埃拉托斯特尼筛法来预处理质数,并使用并查集与图的邻接表来识别和处理路径。首先,通过筛法建立一个质数查找表PRIME,以便快速判断节点编号是否为质数。然后,遍历所有边,如果一条边的两个节点一个为质数一个非质数,则在图中添加这条边;如果两个节点都非质数,则通过并查集合并这两个节点。最后,计算图中每个联通分量的路径数量,只考虑含有质数节点的联通分量,使用组合公式计算每个联通分量中任意两点间路径的总数。

时间复杂度: O(N log log N)

空间复杂度: O(N)

# Prime sieve to identify prime numbers
N = 100001
PRIME = [True] * N
PRIME[1] = False
for i in range(2, N):
    if PRIME[i]:
        for j in range(i * i, N, i):
            PRIME[j] = False
            
# Solution class definition
class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        # Union-Find setup
        UF, rank = list(range(n + 1)), [1] * (n + 1)

        # Find function with path compression
        def find(u: int) -> int:
            if UF[u] != u:
                UF[u] = find(UF[u])
            return UF[u]

        # Union function with union by rank
        def union(u: int, v: int):
            u, v = find(u), find(v)
            if rank[v] < rank[u]:
                u, v = v, u
            UF[u], rank[v] = v, rank[u] + rank[v]

        # Building graph where edges connect primes to non-primes
        G = defaultdict(list)
        for u, v in edges:
            if PRIME[u] != PRIME[v]:
                if PRIME[v]:
                    u, v = v, u
                G[u].append(v)
            elif not PRIME[u]:
                union(u, v)

        # Calculate paths in each connected component
        fn = lambda np: sum(np) + sum(a * b for a, b in zip(np[1:], accumulate(np)))
        return sum(fn([rank[find(k)] for k in cn]) for cn in G.values())

Explore

在此问题中,我们需要处理和统计只包含至少一个质数节点的合法路径。并查集(Union-Find)是一种数据结构,用于高效地处理和查询元素间的连通性问题。在本题的上下文中,我们使用并查集来合并所有非质数节点,因为我们只关心包含质数节点的联通分量。当两个节点都是非质数时,我们将它们合并成一个联通分量。这样一来,我们可以忽略纯非质数的联通分量,从而专注于只包含至少一个质数节点的分量。这种方法降低了问题的复杂度,使我们能够更直接地计算结果。

在本题的解法中,我们的目标是找出所有包含至少一个质数节点的合法路径。当一条边连接一个质数节点和一个非质数节点时,这条边可能是连接两个不同质数节点的不同联通分量的桥梁,因此需要添加到图中以便后续的路径计算。如果两个节点都是质数,根据题解中的策略,这种情况并没有明确说明需要特别处理。在实际应用中,可以考虑是否需要将两个质数节点视为一个潜在的合法路径的起点和终点,这取决于具体问题的需求和定义。

埃拉托斯特尼筛法(Sieve of Eratosthenes)是一种高效的算法,用于找出小于或等于某个整数的所有质数。其原理是从2开始,首先标记2的倍数(除了2本身)为非质数,然后找到下一个未被标记的数字(它一定是质数),再标记其所有倍数为非质数。这个过程重复进行,直到达到指定的数。这种方法之所以有效,是因为它从小到大逐步筛除了合数的同时,保留了质数的标记,且每个合数都被其最小的质因数筛除,从而避免了重复工作。

在题解中,lambda函数`fn`被用来计算每个联通分量中任意两点间路径的总数。这个函数首先计算单个节点的贡献(每个节点都可以单独作为一个路径的起点或终点),然后计算两两节点间组合的路径数。具体来说,`fn`函数中的`sum(np)`计算的是每个节点单独作为路径的贡献,而`sum(a * b for a, b in zip(np[1:], accumulate(np)))`计算的是所有可能的节点对组合,其中每对于每对节点 (a, b),其路径数为 a 和 b 的节点数乘积,因为可以从 a 的任意一个节点开始到 b 的任意一个节点结束。通过这种方式,我们能够利用联通分量中的结构信息来快速计算路径总数,从而有效解决问题。