标签:
难度: Hard
标签:
难度: Hard
运行时间: 1843 ms
内存: 358.3 MB
from functools import cache from collections import Counter from math import perm from typing import List class Solution: def treeOfInfiniteSouls(self, gem: List[int], p: int, target: int) -> int: if p in (2, 5): if target == 9 % p: return perm(len(gem) * 2 - 2, len(gem) - 1) else: return 0 n = len(gem) LIMIT = n - n // 3 r10 = pow(10, p - 2, p) def iter_mask(mask): current = 0 while True: current = ((current | ~mask) + 1) & mask if current == mask: break else: yield current, (~current & mask) @cache def get_length(mask): size = -2 for i, g in enumerate(gem): if (1 << i) & mask: size += 4 + len(str(g)) return size @cache def find_all(mask): if mask.bit_count() == 1: return [int("1" + str(gem[mask.bit_length() - 1]) + "9") % p] result = [] for left, right in iter_mask(mask): left_length = get_length(left) right_length = get_length(right) base = pow(10, left_length + right_length + 1, p) mul = pow(10, right_length + 1, p) for lr in find_all(left): lm = lr * mul + base + 9 for rr in find_all(right): result.append((lm + rr * 10) % p) return result @cache def find_all_counter(mask): return Counter(find_all(mask)) def find_target(mask, target_list): if mask.bit_count() <= LIMIT: return sum(find_all_counter(mask).get(t, 0) for t in target_list) count = 0 for left, right in iter_mask(mask): left_length = get_length(left) right_length = get_length(right) base = pow(10, left_length + right_length + 1, p) if left.bit_count() <= right.bit_count(): mul = pow(10, right_length + 1, p) rr_list = [] for lr in find_all(left): lm = lr * mul + base + 9 rr_list.extend(((t - lm) * r10) % p for t in target_list) count += find_target(right, rr_list) else: mulr = pow(r10, right_length + 1, p) lr_list = [] for rr in find_all(right): rm = base + 9 + rr * 10 lr_list.extend(((t - rm) * mulr) % p for t in target_list) count += find_target(left, lr_list) return count return find_target((1 << n) - 1, [target]) sol = Solution() print(sol.treeOfInfiniteSouls([1, 2, 3], 100, 12319))
This solution utilizes dynamic programming and bitmasking to generate all possible binary tree structures using the gems as leaves and then computes the numeric value generated by these trees. Special care is taken with mod arithmetic to handle the large numbers produced during the calculation. The core idea is to recursively calculate potential results for each subset of gems represented as bit masks, and combine these results according to the binary tree structure to check if any of these results modulo p equals the target. The bitmask represents which gems are included in a subtree, and recursive calls help construct the left and right subtrees, with memorization (caching) used to avoid redundant calculations.
时间复杂度: O(2^n * n)
空间复杂度: O(2^n)
from functools import cache from collections import Counter from math import perm from typing import List class Solution: def treeOfInfiniteSouls(self, gem: List[int], p: int, target: int) -> int: # Special handling for edge cases involving prime factors of 10 if p in (2, 5): if target == 9 % p: return perm(len(gem) * 2 - 2, len(gem) - 1) else: return 0 n = len(gem) LIMIT = n - n // 3 r10 = pow(10, p - 2, p) # Calculate 10^(p-2) mod p to use in reversal # Function to iterate through all possible subsets of mask def iter_mask(mask): current = 0 while True: current = ((current | ~mask) + 1) & mask if current == mask: break else: yield current, (~current & mask) # Calculate the total 'length' of the number formed by nodes in the subtree defined by mask @cache def get_length(mask): size = -2 for i, g in enumerate(gem): if (1 << i) & mask: size += 4 + len(str(g)) return size # Recursively find all possible numbers formed by the tree rooted at a mask @cache def find_all(mask): if mask.bit_count() == 1: return [int("1" + str(gem[mask.bit_length() - 1]) + "9") % p] result = [] for left, right in iter_mask(mask): left_length = get_length(left) right_length = get_length(right) base = pow(10, left_length + right_length + 1, p) mul = pow(10, right_length + 1, p) for lr in find_all(left): lm = lr * mul + base + 9 for rr in find_all(right): result.append((lm + rr * 10) % p) return result # Use Counter to count occurrences of each modulo result @cache def find_all_counter(mask): return Counter(find_all(mask)) # Main function to find how many setups achieve the target modulo def find_target(mask, target_list): if mask.bit_count() <= LIMIT: return sum(find_all_counter(mask).get(t, 0) for t in target_list) count = 0 for left, right in iter_mask(mask): left_length = get_length(left) right_length = get_length(right) base = pow(10, left_length + right_length + 1, p) if left.bit_count() <= right.bit_count(): mul = pow(10, right_length + 1, p) rr_list = [] for lr in find_all(left): lm = lr * mul + base + 9 rr_list.extend(((t - lm) * r10) % p for t in target_list) count += find_target(right, rr_list) else: mulr = pow(r10, right_length + 1, p) lr_list = [] for rr in find_all(right): rm = base + 9 + rr * 10 lr_list.extend(((t - rm) * mulr) % p for t in target_list) count += find_target(left, lr_list) return count return find_target((1 << n) - 1, [target]) sol = Solution() print(sol.treeOfInfiniteSouls([1, 2, 3], 100, 12319))
位掩码(bitmask)表示宝石的子集主要是因为其高效的计算性能和简洁的表达方式。通过位掩码,每个宝石可以通过一个位来表示,其中1代表该宝石被包含在子集中,0则不包含。这种表示方法可以非常方便地通过位运算(如AND, OR, NOT以及XOR)来快速地合并或者查询子集。例如,在生成所有可能的子集以及分割这些子集进行递归计算时,位操作提供了一种非常高效的技术手段。此外,位掩码还可以直接使用整型数进行操作,这在大多数编程语言中都是非常高效的。
在递归函数`find_all`中,基本情况是当子集`mask`只包含一个宝石时。这时,位掩码`mask`中只有一个位是1,表示只有一个宝石被选中。函数通过判断`mask.bit_count() == 1`来识别这种基本情况。对于只有一个宝石的情况,函数会直接返回这个宝石构成的特定格式的数字的模p余数。具体来说,如果宝石的索引是i,那么这个宝石形成的数字格式为`"1" + gem[i] + "9"`,接着将这个字符串转换为整数,然后计算其模p的余数。这个计算结果将作为递归的返回值,用于上层递归中更大子集的计算。
在处理模运算时,计算`10^(p-2) % p`的原因在于需要进行模p运算的逆操作。在某些计算过程中,我们需要得到一个数除以10的模p的结果,这可以通过乘以10的模p逆元来实现。由于10的逆元是`10^(p-1) % p`,而根据费马小定理,当p是质数时,我们有`a^(p-1) % p = 1`,因此`10^(p-2) % p`就是10的模p逆元。在函数`find_target`中,我们需要通过这个逆元来调整数字,以确保在合并左右子树的结果时,能正确地处理除以10的操作,从而维持正确的模p结果。这是处理大数模运算时常用的技巧,特别是在涉及到数字分割和合并的场景中。