Strassen 矩阵乘法

std::vector<std::vector<int>> strassen(std::vector<std::vector<int> > &A, std::vector<std::vector<int> > &B) {
    int n = A.size();
    std::vector<std::vector<int>> R(n, std::vector<int>(n));

    if (n <= 2) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                R[i][j] = 0;
                for (int k = 0; k < n; k++)
                    R[i][j] += A[i][k] * B[k][j];
            }
        }
        return R;
    }
    else {
        int m = n / 2;
        std::vector<std::vector<int>> A11(m, std::vector<int>(m)), A12(m, std::vector<int>(m)), A21(m, std::vector<int>(m)), A22(m, std::vector<int>(m));
        std::vector<std::vector<int>> B11(m, std::vector<int>(m)), B12(m, std::vector<int>(m)), B21(m, std::vector<int>(m)), B22(m, std::vector<int>(m));

        // 以下步骤分解输入矩阵为小矩阵
        // ...

        // Strassen's recursive calls and additions to compute the 7 products
        // and then calculate the 4 quadrants of the resulting matrix
        // ...

        return R;
    }
}

最后更新: 2024年1月2日 21:37:12
创建日期: 2024年1月2日 19:47:42
回到页面顶部