统计区间中的整数数目

标签: 设计 线段树 有序集合

难度: Hard

给你区间的 集,请你设计并实现满足要求的数据结构:

  • 新增:添加一个区间到这个区间集合中。
  • 统计:计算出现在 至少一个 区间中的整数个数。

实现 CountIntervals 类:

  • CountIntervals() 使用区间的空集初始化对象
  • void add(int left, int right) 添加区间 [left, right] 到区间集合之中。
  • int count() 返回出现在 至少一个 区间中的整数个数。

注意:区间 [left, right] 表示满足 left <= x <= right 的所有整数 x

示例 1:

输入
["CountIntervals", "add", "add", "count", "add", "count"]
[[], [2, 3], [7, 10], [], [5, 8], []]
输出
[null, null, null, 6, null, 8]

解释
CountIntervals countIntervals = new CountIntervals(); // 用一个区间空集初始化对象
countIntervals.add(2, 3);  // 将 [2, 3] 添加到区间集合中
countIntervals.add(7, 10); // 将 [7, 10] 添加到区间集合中
countIntervals.count();    // 返回 6
                           // 整数 2 和 3 出现在区间 [2, 3] 中
                           // 整数 7、8、9、10 出现在区间 [7, 10] 中
countIntervals.add(5, 8);  // 将 [5, 8] 添加到区间集合中
countIntervals.count();    // 返回 8
                           // 整数 2 和 3 出现在区间 [2, 3] 中
                           // 整数 5 和 6 出现在区间 [5, 8] 中
                           // 整数 7 和 8 出现在区间 [5, 8] 和区间 [7, 10] 中
                           // 整数 9 和 10 出现在区间 [7, 10] 中

提示:

  • 1 <= left <= right <= 109
  • 最多调用  addcount 方法 总计 105
  • 调用 count 方法至少一次

Submission

运行时间: 416 ms

内存: 57.6 MB

class Node:
    def __init__(self, left, right):
        self.left = left
        self.right = right
        self.prev = None
        self.next = None

class CountIntervals:
    def __init__(self):
        self.head = Node(0, 0)
        self.tail = Node(0, 0)
        self.head.next, self.tail.prev =  self.tail, self.head
        self.cnt = 0

    def add(self, left: int, right: int) -> None:
        if left > self.tail.prev.right:
            self.insert(self.tail.prev, Node(left, right))
            self.cnt += right - left + 1
            return
        p = self.head
        while p.next != self.tail and p.next.right < left:
            p = p.next
        # print(f"add, p got to {p.left} {p.right}")
        # self.print_all()
        if p.next == self.tail or p.next.left > right:
            self.insert(p, Node(left, right))
            self.cnt += right - left + 1
        else:
            if left < p.next.left:
                self.cnt += p.next.left - left
                p.next.left = left
            if p.next.right < right:
                self.update(p.next, right)
    
    def update(self, node, right):
        p = node
        # self.print_all()
        while p.next != self.tail and right >= p.next.left:
            p = p.next
            self.cnt -= p.right - p.left + 1
        right = max(right, p.right)
        self.cnt += right - node.right
        node.right = right
        node.next = p.next
        node.next.prev = node
    
    def insert(self, p, node):
        node.next = p.next
        node.next.prev = node
        p.next = node
        node.prev = p

    def count(self) -> int:
        return self.cnt

Explain

题解采用了双向链表来维护区间集合,每个节点代表一个区间。每次添加新区间时,会尝试与现有区间合并以避免重叠,并维护一个计数器cnt来统计所有涵盖的整数数量。具体操作包括:1. 如果新区间在所有现有区间的右边,则直接添加,并更新计数器。2. 如果新区间和现有区间有重叠,则进行合并,并适当地更新计数器以反映合并后的区间的变化。3. 使用双向链表可以方便地插入和删除节点,同时更新前后节点的链接。

时间复杂度: O(n)

空间复杂度: O(n)


class Node:
    def __init__(self, left, right):
        self.left = left  # 区间左端点
        self.right = right  # 区间右端点
        self.prev = None  # 指向前一个节点的指针
        self.next = None  # 指向下一个节点的指针

class CountIntervals:
    def __init__(self):
        self.head = Node(0, 0)  # 虚拟头节点
        self.tail = Node(0, 0)  # 虚拟尾节点
        self.head.next, self.tail.prev =  self.tail, self.head
        self.cnt = 0  # 维护区间涵盖的整数数量

    def add(self, left: int, right: int) -> None:
        if left > self.tail.prev.right:  # 如果新区间在所有区间的右侧
            self.insert(self.tail.prev, Node(left, right))
            self.cnt += right - left + 1
            return
        p = self.head
        while p.next != self.tail and p.next.right < left:
            p = p.next
        if p.next == self.tail or p.next.left > right:
            self.insert(p, Node(left, right))
            self.cnt += right - left + 1
        else:
            if left < p.next.left:
                self.cnt += p.next.left - left
                p.next.left = left
            if p.next.right < right:
                self.update(p.next, right)
    
    def update(self, node, right):
        p = node
        while p.next != self.tail and right >= p.next.left:
            p = p.next
            self.cnt -= p.right - p.left + 1
        right = max(right, p.right)
        self.cnt += right - node.right
        node.right = right
        node.next = p.next
        node.next.prev = node
    
    def insert(self, p, node):
        node.next = p.next
        node.next.prev = node
        p.next = node
        node.prev = p

    def count(self) -> int:
        return self.cnt  # 返回当前总数量

Explore

如果新添加的区间完全嵌套在现有区间中,例如添加[4,5]到已有的[1,10]区间,实际上不需要对现有区间进行任何修改,也不影响整体的整数数目。在题解的算法中,当寻找到一个现有区间,该区间的左端点小于或等于新增区间的左端点,并且其右端点大于或等于新增区间的右端点时,就可以确定新区间已被现有区间完全包含,因此无需进行插入或计数器更新。这种情况在题解中没有直接体现,可能需要增加逻辑来处理这种特殊情况,避免不必要的合并或区间更新操作。

在`update`函数中,通过从当前节点开始向后遍历,检查每个相邻区间是否与新区间有重叠或相邻。如果有,就将这些区间合并到当前节点中,并更新计数器来反映合并后的整数数目的变化。该函数会减掉所有被合并区间的整数数目,并添加新的合并后的区间的整数数目。遍历继续直到找到一个区间的左端点大于新区间的右端点,确保所有重叠或相邻的区间都被合并。这种方法确保了合并操作的完整性并防止了区间的遗漏。

是的,如果新区间的左端点`left`恰好等于最后一个区间的右端点加一(`self.tail.prev.right + 1`),则应该将这两个区间合并,而不是创建一个新的区间。这种情况在题解算法中未明确处理,需要修改逻辑以检测这种边界条件并进行合并,从而维护区间的连续性和减少不必要的区间分割。

从当前节点开始向后遍历的原因是,新区间与当前节点已经有了重叠,表明合并应该从此处开始。从头开始遍历虽然可以处理所有情况,但效率较低,因为它不利用到当前位置的信息。对于优化策略,可以考虑使用平衡树(如红黑树)或区间树,这些数据结构可以更高效地处理区间的插入、删除和合并,尤其是在处理大量区间操作时可以显著提高效率。