日期:2025年5月11日
ZeroJudge 題目連結:k854. P8.簡易數織 (Nonogram)
解題想法
這題我試著用 dp 計算可能的組合,但是最後一筆測資極大,用 Python 最後一筆測資超時。後來我把題目丢到 Gemini 2.0 Flash,它提供了一個數學化的解法,才能在時限內順利解出來。
Python 程式碼
最後一筆測資超時。
import sys
def solve(w, n, arr):
mod = 10**9 + 7 # 要取餘數用的整數
imin = sum(arr) + n - 1 # 黑色區塊總和 + 至少需要的空格數量
if imin > w: # 需要的格子數大於 w
print(0); return # 印出 0,跳出函式
dp = [[0]*(n+1) for _ in range(w+1)] # dp[i][j] 填到第 i 格時放入第 j 段黑色區塊
for i in range(w+1): dp[i][0] = 1 # 初始化,填滿 0 ~ w 格、沒有黑色區塊,只有一種方式
for i in range(1, w+1): # 填滿 1 ~ w 格
for j in range(1, n+1): # 放入第 j 段黑色區塊
block = arr[j-1] # 黑色區塊長度
dp[i][j] = (dp[i][j] + dp[i-1][j]) % mod # 不放這個黑色區塊
if i >= block: # 可以放入這個黑色區塊
if j == 1: # 第 1 個黑色區塊,前面可以不需要是空格
dp[i][j] = (dp[i][j] + dp[i-block][j-1]) % mod
elif i >= block+1: # 不是第 1 個黑色區塊,前面必須是空格
dp[i][j] = (dp[i][j] + dp[i-block-1][j-1]) % mod
print(dp[w][n]); return
for line in sys.stdin:
w, n = map(int, line.split()) # w 個格子,n 段連續的黑色區塊
arr = list(map(int, input().split())) # 各段連續黑色區塊的長度
solve(w, n, arr)
C++ 程式碼
使用時間約為 13 ms,記憶體約為 11.8 MB,通過測試。
#include <cstdio>
#include <vector>
typedef long long LL;
using namespace std;
void solve(LL w, LL n, LL* arr) {
const LL mod = 1e9 + 7; // 要取餘數用的整數
LL imin = n - 1; // 黑色區塊總和 + 至少需要的空格數量
for(int i=0; i<n; i++) imin += arr[i];
if (imin > w) { // 需要的格子數大於 w
printf("0\n"); return; // 印出 0,跳出函式
}
vector<vector<LL>> dp (w+1, vector<LL> (n+1)); // dp[i][j] 填到第 i 格時放入第 j 段黑色區塊
for(LL i=0; i<=w; i++) dp[i][0] = 1; // 初始化,填滿 0 ~ w 格、沒有黑色區塊,只有一種方式
for(LL i=1; i<=w; i++) { // 填滿 1 ~ w 格
for(LL j=1; j<=n; j++) { // 放入第 j 段黑色區塊
LL block = arr[j-1]; // 黑色區塊長度
dp[i][j] = (dp[i][j] + dp[i-1][j]) % mod; // 不放這個黑色區塊
if (i >= block) { // 可以放入這個黑色區塊
if (j == 1) { // 第 1 個黑色區塊,前面可以不需要是空格
dp[i][j] = (dp[i][j] + dp[i-block][j-1]) % mod;
} else if (i >= block+1) { // 不是第 1 個黑色區塊,前面必須是空格
dp[i][j] = (dp[i][j] + dp[i-block-1][j-1]) % mod;
}
}
}
}
printf("%lld\n", dp[w][n]); return;
}
int main() {
LL w, n; // w 個格子,n 段連續的黑色區塊
while(scanf("%lld %lld", &w, &n) != EOF) {
LL arr[n]; // 各段連續黑色區塊的長度
for(LL i=0; i<n; i++) scanf("%lld", &arr[i]);
solve(w, n, arr);
}
return 0;
}
數學解法
使用時間約為 19 ms,記憶體約為 3.4 MB,通過測試。
def solve_nonogram():
w, n = map(int, input().split())
a = list(map(int, input().split()))
sum_a = sum(a)
remaining_white_space = w - sum_a - (n - 1)
if remaining_white_space < 0:
print(0)
return
r = remaining_white_space
def nCr_mod_p(n, r, p):
if r < 0 or r > n:
return 0
if r == 0 or r == n:
return 1
if r > n // 2:
r = n - r
num = 1
den = 1
for i in range(r):
num = (num * (n - i)) % p
den = (den * (i + 1)) % p
return (num * pow(den, p - 2, p)) % p
mod = 10**9 + 7
result = nCr_mod_p(r + n, n, mod)
print(result)
if __name__ == '__main__':
solve_nonogram()
沒有留言:
張貼留言