连接所有点的最小费用

标签: 并查集 数组 最小生成树

难度: Medium

给你一个points 数组,表示 2D 平面上的一些点,其中 points[i] = [xi, yi] 。

连接点 [xi, yi] 和点 [xj, yj] 的费用为它们之间的 曼哈顿距离 :|xi - xj| + |yi - yj| ,其中 |val| 表示 val 的绝对值。

请你返回将所有点连接的最小总费用。只有任意两点之间 有且仅有 一条简单路径时,才认为所有点都已连接。

示例 1:

输入:points = [[0,0],[2,2],[3,10],[5,2],[7,0]]
输出:20
解释:

我们可以按照上图所示连接所有点得到最小总费用,总费用为 20 。
注意到任意两个点之间只有唯一一条路径互相到达。

示例 2:

输入:points = [[3,12],[-2,5],[-4,1]]
输出:18

示例 3:

输入:points = [[0,0],[1,1],[1,0],[-1,1]]
输出:4

示例 4:

输入:points = [[-1000000,-1000000],[1000000,1000000]]
输出:4000000

示例 5:

输入:points = [[0,0]]
输出:0

提示:

  • 1 <= points.length <= 1000
  • -106 <= xi, yi <= 106
  • 所有点 (xi, yi) 两两不同。

Submission

运行时间: 102 ms

内存: 17.3 MB

class DisjointSetUnion:
    def __init__(self, n):
        self.n = n
        self.rank = [1] * n
        self.f = list(range(n))
    
    def find(self, x: int) -> int:
        if self.f[x] == x:
            return x
        self.f[x] = self.find(self.f[x])
        return self.f[x]
    
    def unionSet(self, x: int, y: int) -> bool:
        fx, fy = self.find(x), self.find(y)
        if fx == fy:
            return False

        if self.rank[fx] < self.rank[fy]:
            fx, fy = fy, fx
        
        self.rank[fx] += self.rank[fy]
        self.f[fy] = fx
        return True

class BIT:
    def __init__(self, n):
        self.n = n
        self.tree = [float("inf")] * n
        self.idRec = [-1] * n
        self.lowbit = lambda x: x & (-x)
    
    def update(self, pos: int, val: int, identity: int):
        while pos > 0:
            if self.tree[pos] > val:
                self.tree[pos] = val
                self.idRec[pos] = identity
            pos -= self.lowbit(pos)

    def query(self, pos: int) -> int:
        minval, j = float("inf"), -1
        while pos < self.n:
            if minval > self.tree[pos]:
                minval = self.tree[pos]
                j = self.idRec[pos]
            pos += self.lowbit(pos)
        return j

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        n = len(points)
        edges = list()

        def build(pos: List[Tuple[int, int, int]]):
            pos.sort()
            a = [y - x for (x, y, _) in pos]
            b = sorted(set(a))
            num = len(b)

            bit = BIT(num + 1)
            for i in range(n - 1, -1, -1):
                poss = bisect.bisect(b, a[i])
                j = bit.query(poss)
                if j != -1:
                    dis = abs(pos[i][0] - pos[j][0]) + abs(pos[i][1] - pos[j][1])
                    edges.append((dis, pos[i][2], pos[j][2]))
                bit.update(poss, pos[i][0] + pos[i][1], i)
        
        def solve():
            pos = [(x, y, i) for i, (x, y) in enumerate(points)]
            build(pos)
            pos = [(y, x, i) for i, (x, y) in enumerate(points)]
            build(pos)
            pos = [(-y, x, i) for i, (x, y) in enumerate(points)]
            build(pos)
            pos = [(x, -y, i) for i, (x, y) in enumerate(points)]
            build(pos)
        
        solve()
        dsu = DisjointSetUnion(n)
        edges.sort()
        
        ret, num = 0, 1
        for length, x, y in edges:
            if dsu.unionSet(x, y):
                ret += length
                num += 1
                if num == n:
                    break
        
        return ret

Explain

本题解利用了最小生成树的 Kruskal 算法和离散化结合线段树优化的方法来寻找连接所有点的最小费用。首先,通过构建四种变换的点集,以应对曼哈顿距离的计算特点,并通过排序和离散化来优化处理。利用线段树(通过二分索引树BIT实现)维护每个离散化后位置的最小距离和对应索引,来高效地查询和更新边的最小费用。然后,将所有潜在的边加入到边列表中。最后,通过Kruskal算法,利用并查集(Disjoint Set Union, DSU)来确定最小生成树,从而找出所有点连接的最小总费用。

时间复杂度: O(n log n)

空间复杂度: O(n)

class DisjointSetUnion:
    def __init__(self, n):
        self.n = n # 节点数量
        self.rank = [1] * n # 用于优化并查集的rank数组
        self.f = list(range(n)) # 并查集的父节点数组
    
    def find(self, x: int) -> int: # 查找根节点,并进行路径压缩
        if self.f[x] == x:
            return x
        self.f[x] = self.find(self.f[x])
        return self.f[x]
    
    def unionSet(self, x: int, y: int) -> bool: # 合并两个节点
        fx, fy = self.find(x), self.find(y)
        if fx == fy:
            return False
        if self.rank[fx] < self.rank[fy]:
            fx, fy = fy, fx
        self.rank[fx] += self.rank[fy]
        self.f[fy] = fx
        return True

class BIT:
    def __init__(self, n):
        self.n = n # 节点数加1,用于线段树操作
        self.tree = [float("inf")] * n # 线段树存储最小值
        self.idRec = [-1] * n # 记录最小值对应的节点索引
        self.lowbit = lambda x: x & (-x) # 计算最低有效位
    
    def update(self, pos: int, val: int, identity: int): # 更新线段树
        while pos > 0:
            if self.tree[pos] > val:
                self.tree[pos] = val
                self.idRec[pos] = identity
            pos -= self.lowbit(pos)
    
    def query(self, pos: int) -> int: # 查询给定范围内的最小值和索引
        minval, j = float("inf"), -1
        while pos < self.n:
            if minval > self.tree[pos]:
                minval = self.tree[pos]
                j = self.idRec[pos]
            pos += self.lowbit(pos)
        return j

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int: # 主函数
        n = len(points) # 点的数量
        edges = list() # 存储所有边的列表
        def build(pos: List[Tuple[int, int, int]]): # 构建边的函数
            pos.sort() # 对点进行排序
            a = [y - x for (x, y, _) in pos] # 计算曼哈顿距离的一部分
            b = sorted(set(a)) # 离散化处理
            num = len(b) # 离散化后的数量
            bit = BIT(num + 1) # 创建线段树
            for i in range(n - 1, -1, -1): # 逆序处理以维护正确的最小值
                poss = bisect.bisect(b, a[i]) # 二分查找离散化位置
                j = bit.query(poss) # 查询最小值对应的索引
                if j != -1: # 如果找到有效的最小值
                    dis = abs(pos[i][0] - pos[j][0]) + abs(pos[i][1] - pos[j][1]) # 计算曼哈顿距离
                    edges.append((dis, pos[i][2], pos[j][2])) # 将边加入到列表中
                bit.update(poss, pos[i][0] + pos[i][1], i) # 更新线段树
        def solve(): # 处理所有变换并构建边
            pos = [(x, y, i) for i, (x, y) in enumerate(points)] # 原始坐标
            build(pos) # 构建边
            pos = [(y, x, i) for i, (x, y) in enumerate(points)] # 交换x, y以处理不同的曼哈顿距离
            build(pos)
            pos = [(-y, x, i) for i, (x, y) in enumerate(points)] # 反转y坐标
            build(pos)
            pos = [(x, -y, i) for i, (x, y) in enumerate(points)] # 反转x坐标
            build(pos)
        solve() # 调用solve函数构建所有边
        dsu = DisjointSetUnion(n) # 创建并查集实例
        edges.sort() # 根据边的权重排序,以便Kruskal算法处理
        ret, num = 0, 1 # 初始化最小生成树的总权重和计数器
        for length, x, y in edges: # 遍历所有边
            if dsu.unionSet(x, y): # 尝试合并节点
                ret += length # 累加权重
                num += 1 # 增加计数器
                if num == n: # 如果已经连接了所有节点
                    break
        return ret # 返回最小生成树的总权重

Explore

题解中提到的四种变换的点集包括:原始坐标 (x, y)、交换 x 和 y 坐标 (y, x)、反转 y 坐标 (-y, x)、反转 x 坐标 (x, -y)。这些变换是为了解决曼哈顿距离在不同方向上的最小化问题。曼哈顿距离定义为两点间的绝对横向距离与纵向距离之和。通过这些变换,算法能够在各个方向上寻找最小的曼哈顿距离,以确保能够覆盖所有可能的最小距离情况。每种变换使得问题从原始的多维问题简化为一维问题,方便使用线段树进行求解。

在构建边的函数中对点进行逆序处理是为了确保在更新线段树时,能够正确地维护到达每个点的最小可能成本。逆序处理(从后向前遍历点集)可以避免未来的点影响已经处理过的点的最小距离记录。如果从前向后处理,那么在更新线段树时,较后面的点可能会影响到较前面的点的最小值记录,从而使得不能正确记录到达每个点的最小距离。逆序处理确保了每次更新线段树时,只考虑当前点之前的点,从而正确维护各个点的最小距离。

离散化处理的目的是为了将连续或较大范围的数据值映射到较小的、连续的整数索引上,这样可以减小线段树所需处理的数据范围,从而提高效率和简化实现。在题解中,对坐标进行离散化是为了能够在线段树中有效地存储和查询数据,因为线段树处理的是数组索引,而不是直接处理坐标值。通过将坐标值映射到连续的整数索引,可以使用数组直接存取数据,大大提高了空间和时间效率。离散化后,线段树可以在更小的空间内进行操作,并且操作的时间复杂度降低,使得算法整体更加高效。