万灵之树

标签:

难度: Hard

探险家小扣终于来到了万灵之树前,挑战最后的谜题。 已知小扣拥有足够数量的链接节点和 `n` 颗幻境宝石,`gem[i]` 表示第 `i` 颗宝石的数值。现在小扣需要使用这些链接节点和宝石组合成一颗二叉树,其组装规则为: - 链接节点将作为二叉树中的非叶子节点,且每个链接节点必须拥有 `2` 个子节点; - 幻境宝石将作为二叉树中的叶子节点,所有的幻境宝石都必须被使用。 能量首先进入根节点,而后将按如下规则进行移动和记录: - 若能量首次到达该节点时: - 记录数字 `1`; - 若该节点为叶节点,将额外记录该叶节点的数值; - 若存在未到达的子节点,则选取未到达的一个子节点(优先选取左子节点)进入; - 若无子节点或所有子节点均到达过,此时记录 `9`,并回到当前节点的父节点(若存在)。 如果最终记下的数依序连接成一个整数 `num`,满足 $num \mod~p=target$,则视为解开谜题。 请问有多少种二叉树的组装方案,可以使得最终记录下的数字可以解开谜题 **注意:** - 两棵结构不同的二叉树,作为不同的组装方案 - 两棵结构相同的二叉树且存在某个相同位置处的宝石编号不同,也作为不同的组装方案 - 可能存在数值相同的两颗宝石 **示例 1:** > 输入:`gem = [2,3]` > `p = 100000007` > `target = 11391299` > > 输出:`1` > > 解释: > 包含 `2` 个叶节点的结构只有一种。 > 假设 B、C 节点的值分别为 3、2,对应 target 为 11391299,如下图所示。 > 11391299 % 100000007 = 11391299,满足条件; > 假设 B、C 节点的值分别为 2、3,对应 target 为 11291399; > 11291399 % 100000007 = 11291399,不满足条件; > 因此只存在 1 种方案,返回 1 ![万灵 (1).gif](https://pic.leetcode.cn/1682397079-evMssw-%E4%B8%87%E7%81%B5%20\(1\).gif){:height=300px} **示例 2:** > 输入:`gem = [3,21,3]` > `p = 7` > `target = 5` > > 输出:`4` > > 解释: 包含 `3` 个叶节点树结构有两种,列举如下: 满足条件的组合有四种情况: > 当结构为下图(1)时:叶子节点的值为 [3,3,21] 或 [3,3,21],得到的整数为 `11139139912199`。 > 当结构为下图(2)时:叶子节点的值为 [21,3,3] 或 [21,3,3],得到的整数为 `11219113913999`。 ![image.png](https://pic.leetcode.cn/1682322894-vfqJIV-image.png){:width=500px} **提示:** - `1 <= gem.length <= 9` - `0 <= gem[i] <= 10^9` - `1 <= p <= 10^9`,保证 $p$ 为素数。 - `0 <= target < p` - 存在 2 组 `gem.length == 9` 的用例

Submission

运行时间: 1843 ms

内存: 358.3 MB

from functools import cache
from collections import Counter
from math import perm
from typing import List


class Solution:
    def treeOfInfiniteSouls(self, gem: List[int], p: int, target: int) -> int:
        if p in (2, 5):
            if target == 9 % p:
                return perm(len(gem) * 2 - 2, len(gem) - 1)
            else:
                return 0
        n = len(gem)
        LIMIT = n - n // 3
        r10 = pow(10, p - 2, p)

        def iter_mask(mask):
            current = 0
            while True:
                current = ((current | ~mask) + 1) & mask
                if current == mask:
                    break
                else:
                    yield current, (~current & mask)

        @cache
        def get_length(mask):
            size = -2
            for i, g in enumerate(gem):
                if (1 << i) & mask:
                    size += 4 + len(str(g))
            return size

        @cache
        def find_all(mask):
            if mask.bit_count() == 1:
                return [int("1" + str(gem[mask.bit_length() - 1]) + "9") % p]
            result = []
            for left, right in iter_mask(mask):
                left_length = get_length(left)
                right_length = get_length(right)
                base = pow(10, left_length + right_length + 1, p)
                mul = pow(10, right_length + 1, p)
                for lr in find_all(left):
                    lm = lr * mul + base + 9
                    for rr in find_all(right):
                        result.append((lm + rr * 10) % p)
            return result

        @cache
        def find_all_counter(mask):
            return Counter(find_all(mask))

        def find_target(mask, target_list):
            if mask.bit_count() <= LIMIT:
                return sum(find_all_counter(mask).get(t, 0) for t in target_list)
            count = 0
            for left, right in iter_mask(mask):
                left_length = get_length(left)
                right_length = get_length(right)
                base = pow(10, left_length + right_length + 1, p)
                if left.bit_count() <= right.bit_count():
                    mul = pow(10, right_length + 1, p)
                    rr_list = []
                    for lr in find_all(left):
                        lm = lr * mul + base + 9
                        rr_list.extend(((t - lm) * r10) % p for t in target_list)
                    count += find_target(right, rr_list)
                else:
                    mulr = pow(r10, right_length + 1, p)
                    lr_list = []
                    for rr in find_all(right):
                        rm = base + 9 + rr * 10
                        lr_list.extend(((t - rm) * mulr) % p for t in target_list)
                    count += find_target(left, lr_list)
            return count

        return find_target((1 << n) - 1, [target])


sol = Solution()
print(sol.treeOfInfiniteSouls([1, 2, 3], 100, 12319))

Explain

This solution utilizes dynamic programming and bitmasking to generate all possible binary tree structures using the gems as leaves and then computes the numeric value generated by these trees. Special care is taken with mod arithmetic to handle the large numbers produced during the calculation. The core idea is to recursively calculate potential results for each subset of gems represented as bit masks, and combine these results according to the binary tree structure to check if any of these results modulo p equals the target. The bitmask represents which gems are included in a subtree, and recursive calls help construct the left and right subtrees, with memorization (caching) used to avoid redundant calculations.

时间复杂度: O(2^n * n)

空间复杂度: O(2^n)

from functools import cache
from collections import Counter
from math import perm
from typing import List

class Solution:
    def treeOfInfiniteSouls(self, gem: List[int], p: int, target: int) -> int:
        # Special handling for edge cases involving prime factors of 10
        if p in (2, 5):
            if target == 9 % p:
                return perm(len(gem) * 2 - 2, len(gem) - 1)
            else:
                return 0
        n = len(gem)
        LIMIT = n - n // 3
        r10 = pow(10, p - 2, p)  # Calculate 10^(p-2) mod p to use in reversal

        # Function to iterate through all possible subsets of mask
        def iter_mask(mask):
            current = 0
            while True:
                current = ((current | ~mask) + 1) & mask
                if current == mask:
                    break
                else:
                    yield current, (~current & mask)

        # Calculate the total 'length' of the number formed by nodes in the subtree defined by mask
        @cache
        def get_length(mask):
            size = -2
            for i, g in enumerate(gem):
                if (1 << i) & mask:
                    size += 4 + len(str(g))
            return size

        # Recursively find all possible numbers formed by the tree rooted at a mask
        @cache
        def find_all(mask):
            if mask.bit_count() == 1:
                return [int("1" + str(gem[mask.bit_length() - 1]) + "9") % p]
            result = []
            for left, right in iter_mask(mask):
                left_length = get_length(left)
                right_length = get_length(right)
                base = pow(10, left_length + right_length + 1, p)
                mul = pow(10, right_length + 1, p)
                for lr in find_all(left):
                    lm = lr * mul + base + 9
                    for rr in find_all(right):
                        result.append((lm + rr * 10) % p)
            return result

        # Use Counter to count occurrences of each modulo result
        @cache
        def find_all_counter(mask):
            return Counter(find_all(mask))

        # Main function to find how many setups achieve the target modulo
        def find_target(mask, target_list):
            if mask.bit_count() <= LIMIT:
                return sum(find_all_counter(mask).get(t, 0) for t in target_list)
            count = 0
            for left, right in iter_mask(mask):
                left_length = get_length(left)
                right_length = get_length(right)
                base = pow(10, left_length + right_length + 1, p)
                if left.bit_count() <= right.bit_count():
                    mul = pow(10, right_length + 1, p)
                    rr_list = []
                    for lr in find_all(left):
                        lm = lr * mul + base + 9
                        rr_list.extend(((t - lm) * r10) % p for t in target_list)
                    count += find_target(right, rr_list)
                else:
                    mulr = pow(r10, right_length + 1, p)
                    lr_list = []
                    for rr in find_all(right):
                        rm = base + 9 + rr * 10
                        lr_list.extend(((t - rm) * mulr) % p for t in target_list)
                    count += find_target(left, lr_list)
            return count

        return find_target((1 << n) - 1, [target])

sol = Solution()
print(sol.treeOfInfiniteSouls([1, 2, 3], 100, 12319))

Explore

位掩码(bitmask)表示宝石的子集主要是因为其高效的计算性能和简洁的表达方式。通过位掩码,每个宝石可以通过一个位来表示,其中1代表该宝石被包含在子集中,0则不包含。这种表示方法可以非常方便地通过位运算(如AND, OR, NOT以及XOR)来快速地合并或者查询子集。例如,在生成所有可能的子集以及分割这些子集进行递归计算时,位操作提供了一种非常高效的技术手段。此外,位掩码还可以直接使用整型数进行操作,这在大多数编程语言中都是非常高效的。

在递归函数`find_all`中,基本情况是当子集`mask`只包含一个宝石时。这时,位掩码`mask`中只有一个位是1,表示只有一个宝石被选中。函数通过判断`mask.bit_count() == 1`来识别这种基本情况。对于只有一个宝石的情况,函数会直接返回这个宝石构成的特定格式的数字的模p余数。具体来说,如果宝石的索引是i,那么这个宝石形成的数字格式为`"1" + gem[i] + "9"`,接着将这个字符串转换为整数,然后计算其模p的余数。这个计算结果将作为递归的返回值,用于上层递归中更大子集的计算。

在处理模运算时,计算`10^(p-2) % p`的原因在于需要进行模p运算的逆操作。在某些计算过程中,我们需要得到一个数除以10的模p的结果,这可以通过乘以10的模p逆元来实现。由于10的逆元是`10^(p-1) % p`,而根据费马小定理,当p是质数时,我们有`a^(p-1) % p = 1`,因此`10^(p-2) % p`就是10的模p逆元。在函数`find_target`中,我们需要通过这个逆元来调整数字,以确保在合并左右子树的结果时,能正确地处理除以10的操作,从而维持正确的模p结果。这是处理大数模运算时常用的技巧,特别是在涉及到数字分割和合并的场景中。