日期:2026年6月12日
LeetCode 題目連結:3559. Number of Ways to Assign Edge Weights II
解題想法
困難題。由於節點數量 $n$ 可能多達 $10^5$ 個,用暴力法找出任意兩個節點之間的距離一定會超時,需要利用倍增法 (binary lift) 找出要查詢的兩個節點的最近共同祖先 (lowest common ancestor, LCA),主要分成以下的步驟:
- 用接鄰矩陣存圖
- 先開好需要用到的串列,例如儲存每個節點深度用的 $depth$,儲存節點 $i$ 的第 $2^j$ 代祖先的二維串列 $up$。
- 用廣度優先搜尋 (BFS) 初始化 depth 和第 1 代祖先 up[i][0]
- 建立倍增表 $up$,其中 $up[i][j]$ 代表節點 $i$ 的第 $2^j$ 代祖先,由於 $2^j = 2^{j-1} \times 2^{j-1}$,因此建表時 $up[i][j] = up[up[i][j-1]][j-1]$。
- 寫一個查詢最近共同祖先的自訂函式 get_lca,找出代入的節點 $u, v$ 的最近共同祖先。
- 假設要查詢的節點為 $u, v$,最短距離 $dist = depth[u] + depth[v] - 2 \times depth[lca]$,答案為 $(2^{dist - 1}) % 1000000007$。
Python 程式碼
Runtime: 1463 ms, beats 35.85%. Memory: 99.75 MB, beats 88.68%.
from collections import deque
class Solution:
def assignEdgeWeights(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
# --- 用接鄰矩陣存圖 ---
n = len(edges) + 1
adj = [[] for _ in range(n+1)]
for u, v in edges:
adj[u].append(v)
adj[v].append(u)
# --- 用倍增法找節點 i 的第 2**j 代祖先
LOG = 20 # 2**17 > 10**5,取 20 一定夠用
depth = [0] * (n+1) # 節點的深度
up = [[1] * LOG for _ in range(n+1)] # 儲存節點 i 的第 2**j 代祖先
# --- 用 BFS 初始化 depth 和第 1 代祖先 up[i][0]
que = deque([1])
visited = {1}
depth[1] = 0
while que:
u = que.popleft()
for v in adj[u]:
if v not in visited:
visited.add(v)
depth[v] = depth[u] + 1
up[v][0] = u # v 的父節點為 u
que.append(v)
# --- 建立倍增表 (binary lift)
# up[i][j] 為 i 的第 2**(j-1) 代祖先的第 2**(j-1) 代祖先
for j in range(1, LOG): # 控制代數的 j 要放外㽪
for i in range(1, n+1):
up[i][j] = up[up[i][j-1]][j-1] # 因為 2**j = 2**(j-1) * 2**(j-1)
# --- 查詢最近共同祖先 (lowest common ancestor, LCA)
def get_lca(u, v):
# 防呆,確保 u 是比較深的節點
if depth[u] < depth[v]:
u, v = v, u
# 將 u 往上提,直到跟 v 位於同一深度
diff = depth[u] - depth[v]
for j in range(LOG):
if (diff >> j) & 1:
u = up[u][j]
# 如果 u 往上提之後等於 u,直接回傳 u
if u == v: return u
# u 和 v 一起往上提,直到找到 LCA 的子節點
for j in range(LOG - 1, -1, -1):
if up[u][j] != up[v][j]:
u = up[u][j]
v = up[v][j]
return up[u][0] # 回傳 u 的父節點
# --- 處理所有查詢 ---
MOD = 10**9 + 7
ans = []
for u, v in queries:
lca = get_lca(u, v)
# 計算兩點在樹上的距離,邊的數量
dist = depth[u] + depth[v] - 2 * depth[lca]
if dist == 0:
ans.append(0)
else:
ans.append(pow(2, dist - 1, MOD))
return ans
C++ 程式碼
Runtime: 876 ms, beats 5.56%. Memory: 447.40 MB, beats 5.56%.
class Solution {
public:
const int LOG = 20; // 2**17 > 10**5,取 20 一定夠用
const int MOD = 1000000007;
int n; // 節點數量
vector<vector<int>> adj, up; // 接鄰矩陣,節點 i 的第 2**j 代祖先
vector<int> depth; // 節點的深度
// 快速冪,計算 a**b % MOD
long mypow(long a, long b) {
long r = 1;
while(b) {
if (b&1) r = (r * a) % MOD;
a = a * a % MOD;
b >>= 1;
}
return r;
}
// 查詢最近共同祖先 (lowest common ancestor, LCA)
int get_lca(int u, int v) {
// 防呆,確保 u 是比較深的節點
if (depth[u] < depth[v]) swap(u, v);
// 將 u 往上提,直到跟 v 位於同一深度
int diff = depth[u] - depth[v];
for(int j = 0; j < LOG; j++) {
if ((diff >> j) & 1) u = up[u][j];
}
// 如果 u 往上提之後等於 u,直接回傳 u
if (u == v) return u;
// u 和 v 一起往上提,直到找到 LCA 的子節點
for(int j = LOG - 1; j >= 0; j--) {
if (up[u][j] != up[v][j]) {
u = up[u][j];
v = up[v][j];
}
}
return up[u][0]; // 回傳 u 的父節點
}
vector<int> assignEdgeWeights(vector<vector<int>>& edges, vector<vector<int>>& queries) {
/* 用接鄰矩陣存圖 */
n = (int)edges.size() + 1;
adj.assign(n+1, vector<int> (0));
for(auto it : edges) {
int u = it[0], v = it[1];
adj[u].push_back(v);
adj[v].push_back(u);
}
/* 用倍增法找節點 i 的第 2**j 代祖先 */
depth.assign(n+1, 0); // 節點的深度
up.assign(n+1, vector<int> (LOG, 1)); // 儲存節點 i 的第 2**j 代祖先
/* 用 BFS 初始化 depth 和第 1 代祖先 (up[i][0]) */
queue<int> que;
que.push(1);
set<int> visited = {1};
depth[1] = 0;
while(!que.empty()) {
int u = que.front();
que.pop();
for(int v : adj[u]) {
if (visited.count(v) == 0) {
visited.insert(v);
depth[v] = depth[u] + 1;
up[v][0] = u; // v 的父節點為 u
que.push(v);
}
}
}
/* 建立倍增表 (binary lift) */
// up[i][j] 為 i 的第 2**(j-1) 代祖先的第 2**(j-1) 代祖先
for(int j = 1; j < LOG; j++) { // 控制代數的 j 要放外㽪
for(int i = 1; i <= n; i++) {
up[i][j] = up[up[i][j-1]][j-1]; // 因為 2**j = 2**(j-1) * 2**(j-1)
}
}
/* 處理所有查詢 */
vector<int> ans;
for(auto it : queries) {
int u = it[0], v = it[1];
int lca = get_lca(u, v);
// 計算兩點在樹上的距離,邊的數量
int dist = depth[u] + depth[v] - 2 * depth[lca];
if (dist == 0) {
ans.push_back(0);
} else {
ans.push_back(mypow(2, dist - 1));
}
}
return ans;
}
};
沒有留言:
張貼留言