熱門文章

2025年12月29日 星期一

ZeroJudge 解題筆記:d481. 矩陣乘法

作者:王一哲
日期:2025年12月29日


ZeroJudge 題目連結:d481. 矩陣乘法

解題想法


按照矩陣乘法的定義,用 3 層 for 迴圈取出對應的元素相乘再加總。

Python 程式碼


使用時間約為 19 ms,記憶體約為 3.4 MB,通過測試。
import sys

def matmul(u, v):
    a, b, d = len(u), len(u[0]), len(v[0])
    w = [[0]*d for _ in range(a)]
    for i in range(a):
        for j in range(d):
            for k in range(b):
                w[i][j] += u[i][k] * v[k][j]
    return w

for line in sys.stdin:
    a, b, c, d = map(int, line.split())
    if b != c:
        print("Error")
        continue
    u = [list(map(int, sys.stdin.readline().split())) for _ in range(a)]
    v = [list(map(int, sys.stdin.readline().split())) for _ in range(c)]
    w = matmul(u, v)
    for row in w: print(*row)


C++ 程式碼


使用時間約為 2 ms,記憶體約為 80 kB,通過測試。
#include <cstdio>
#include <cstring>

int main() {
    long a, b, c, d;
    while(scanf("%ld %ld %ld %ld", &a, &b, &c, &d) != EOF) {
        if (b != c) {
            puts("Error");
            continue;
        }
        long u[a][b], v[c][d], w[a][d];
        memset(w, 0, sizeof(w));
        for(int i=0; i<a; i++) {
            for(int j=0; j<b; j++) {
                scanf("%ld", &u[i][j]);
            }
        }
        for(int i=0; i<c; i++) {
            for(int j=0; j<d; j++) {
                scanf("%ld", &v[i][j]);
            }
        }
        for(int i=0; i<a; i++) {
            for(int j=0; j<d; j++) {
                for(int k=0; k<b; k++) {
                    w[i][j] += u[i][k] * v[k][j];
                }
            }
        }
        for(int i=0; i<a; i++) {
            for(int j=0; j<d-1; j++) printf("%ld ", w[i][j]);
            printf("%ld\n", w[i][d-1]);
        }
    }
    return 0;
}

使用時間約為 2 ms,記憶體約為 284 kB,通過測試。
#include <cstdio>
#include <vector>
using namespace std;

void matmul(vector<vector<long>>& u, vector<vector<long>>& v, vector<vector<long>>& w) {
    size_t a = u.size(), b = u[0].size(), d = v[0].size();
    for(size_t i=0; i<a; i++) {
        for(size_t j=0; j<d; j++) {
            for(size_t k=0; k<b; k++) {
                w[i][j] += u[i][k] * v[k][j];
            }
        }
    }
}

int main() {
    long a, b, c, d;
    while(scanf("%ld %ld %ld %ld", &a, &b, &c, &d) != EOF) {
        if (b != c) {
            puts("Error");
            continue;
        }
        vector<vector<long>> u (a, vector<long> (b)), v (c, vector<long> (d)), w (a, vector<long> (d, 0));
        for(int i=0; i<a; i++) {
            for(int j=0; j<b; j++) {
                scanf("%ld", &u[i][j]);
            }
        }
        for(int i=0; i<c; i++) {
            for(int j=0; j<d; j++) {
                scanf("%ld", &v[i][j]);
            }
        }
        matmul(u, v, w);
        for(int i=0; i<a; i++) {
            for(int j=0; j<d-1; j++) printf("%ld ", w[i][j]);
            printf("%ld\n", w[i][d-1]);
        }
    }
    return 0;
}


沒有留言:

張貼留言