日期:2026年6月12日
LeetCode 題目連結:3559. Number of Ways to Assign Edge Weights II
解題想法
困難題。由於節點數量 $n$ 可能多達 $10^5$ 個,用暴力法找出任意兩個節點之間的距離一定會超時,需要利用倍增法 (binary lift) 找出要查詢的兩個節點的最近共同祖先 (lowest common ancestor, LCA),主要分成以下的步驟:
- 用接鄰矩陣存圖
- 先開好需要用到的串列,例如儲存每個節點深度用的 $depth$,儲存節點 $i$ 的第 $2^j$ 代祖先的二維串列 $up$。
- 用廣度優先搜尋 (BFS) 初始化 depth 和第 1 代祖先 up[i][0]
- 建立倍增表 $up$,其中 $up[i][j]$ 代表節點 $i$ 的第 $2^j$ 代祖先,由於 $2^j = 2^{j-1} \times 2^{j-1}$,因此建表時 $up[i][j] = up[up[i][j-1]][j-1]$。
- 寫一個查詢最近共同祖先的自訂函式 get_lca,找出代入的節點 $u, v$ 的最近共同祖先。
- 假設要查詢的節點為 $u, v$,最短距離 $dist = depth[u] + depth[v] - 2 \times depth[lca]$,答案為 $(2^{dist - 1}) % 1000000007$。
Python 程式碼
Runtime: 1463 ms, beats 35.85%. Memory: 99.75 MB, beats 88.68%.
from collections import deque
class Solution:
def assignEdgeWeights(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
# --- 用接鄰矩陣存圖 ---
n = len(edges) + 1
adj = [[] for _ in range(n+1)]
for u, v in edges:
adj[u].append(v)
adj[v].append(u)
# --- 用倍增法找節點 i 的第 2**j 代祖先
LOG = 20 # 2**17 > 10**5,取 20 一定夠用
depth = [0] * (n+1) # 節點的深度
up = [[1] * LOG for _ in range(n+1)] # 儲存節點 i 的第 2**j 代祖先
# --- 用 BFS 初始化 depth 和第 1 代祖先 up[i][0]
que = deque([1])
visited = {1}
depth[1] = 0
while que:
u = que.popleft()
for v in adj[u]:
if v not in visited:
visited.add(v)
depth[v] = depth[u] + 1
up[v][0] = u # v 的父節點為 u
que.append(v)
# --- 建立倍增表 (binary lift)
# up[i][j] 為 i 的第 2**(j-1) 代祖先的第 2**(j-1) 代祖先
for j in range(1, LOG): # 控制代數的 j 要放外㽪
for i in range(1, n+1):
up[i][j] = up[up[i][j-1]][j-1] # 因為 2**j = 2**(j-1) * 2**(j-1)
# --- 查詢最近共同祖先 (lowest common ancestor, LCA)
def get_lca(u, v):
# 防呆,確保 u 是比較深的節點
if depth[u] < depth[v]:
u, v = v, u
# 將 u 往上提,直到跟 v 位於同一深度
diff = depth[u] - depth[v]
for j in range(LOG):
if (diff >> j) & 1:
u = up[u][j]
# 如果 u 往上提之後等於 u,直接回傳 u
if u == v: return u
# u 和 v 一起往上提,直到找到 LCA 的子節點
for j in range(LOG - 1, -1, -1):
if up[u][j] != up[v][j]:
u = up[u][j]
v = up[v][j]
return up[u][0] # 回傳 u 的父節點
# --- 處理所有查詢 ---
MOD = 10**9 + 7
ans = []
for u, v in queries:
lca = get_lca(u, v)
# 計算兩點在樹上的距離,邊的數量
dist = depth[u] + depth[v] - 2 * depth[lca]
if dist == 0:
ans.append(0)
else:
ans.append(pow(2, dist - 1, MOD))
return ans