From 3e55d21feb447ef194cc45e93ef31454ad7498ac Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 19:23:22 +0530 Subject: [PATCH 1/8] feat: Strassen's matrix multiplication algorithm added --- matrix/strassen_matrix_multiplication.py | 132 +++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 matrix/strassen_matrix_multiplication.py diff --git a/matrix/strassen_matrix_multiplication.py b/matrix/strassen_matrix_multiplication.py new file mode 100644 index 000000000000..708a7bbce9dc --- /dev/null +++ b/matrix/strassen_matrix_multiplication.py @@ -0,0 +1,132 @@ +from typing import List + +Matrix = List[List[int]] + +def add(A: Matrix, B: Matrix) -> Matrix: + n = len(A) + return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)] + +def sub(A: Matrix, B: Matrix) -> Matrix: + n = len(A) + return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)] + +def naive_mul(A: Matrix, B: Matrix) -> Matrix: + n = len(A) + C = [[0]*n for _ in range(n)] + for i in range(n): + ai = A[i] + ci = C[i] + for k in range(n): + a_ik = ai[k] + bk = B[k] + for j in range(n): + ci[j] += a_ik * bk[j] + return C + +def next_power_of_two(n: int) -> int: + p = 1 + while p < n: + p <<= 1 + return p + +def pad_matrix(A: Matrix, size: int) -> Matrix: + n = len(A) + padded = [[0]*size for _ in range(size)] + for i in range(n): + for j in range(len(A[0])): + padded[i][j] = A[i][j] + return padded + +def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix: + return [row[:cols] for row in A[:rows]] + +def split(A: Matrix) -> tuple: + n = len(A) + mid = n // 2 + A11 = [[A[i][j] for j in range(mid)] for i in range(mid)] + A12 = [[A[i][j] for j in range(mid, n)] for i in range(mid)] + A21 = [[A[i][j] for j in range(mid)] for i in range(mid, n)] + A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)] + return A11, A12, A21, A22 + +def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix: + n2 = len(C11) + n = n2 * 2 + C = [[0]*n for _ in range(n)] + for i in range(n2): + for j in range(n2): + C[i][j] = C11[i][j] + C[i][j + n2] = C12[i][j] + C[i + n2][j] = C21[i][j] + C[i + n2][j + n2] = C22[i][j] + return C + +def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix: + """ + Multiply square matrices A and B using Strassen algorithm. + threshold: below this size, uses naive multiplication (tweakable). + """ + assert len(A) == len(A[0]) == len(B) == len(B[0]), "Only square matrices supported in this implementation" + + n_orig = len(A) + if n_orig == 0: + return [] + + m = next_power_of_two(n_orig) + if m != n_orig: + A_pad = pad_matrix(A, m) + B_pad = pad_matrix(B, m) + else: + A_pad, B_pad = A, B + + C_pad = _strassen_recursive(A_pad, B_pad, threshold) + + C = unpad_matrix(C_pad, n_orig, n_orig) + return C + +def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: + n = len(A) + if n <= threshold: + return naive_mul(A, B) + if n == 1: + return [[A[0][0] * B[0][0]]] + + A11, A12, A21, A22 = split(A) + B11, B12, B21, B22 = split(B) + + M1 = _strassen_recursive(add(A11, A22), add(B11, B22), threshold) + M2 = _strassen_recursive(add(A21, A22), B11, threshold) + M3 = _strassen_recursive(A11, sub(B12, B22), threshold) + M4 = _strassen_recursive(A22, sub(B21, B11), threshold) + M5 = _strassen_recursive(add(A11, A12), B22, threshold) + M6 = _strassen_recursive(sub(A21, A11), add(B11, B12), threshold) + M7 = _strassen_recursive(sub(A12, A22), add(B21, B22), threshold) + + C11 = add(sub(add(M1, M4), M5), M7) + C12 = add(M3, M5) + C21 = add(M2, M4) + C22 = add(sub(add(M1, M3), M2), M6) + + return join(C11, C12, C21, C22) + +if __name__ == "__main__": + A = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ] + B = [ + [9, 8, 7], + [6, 5, 4], + [3, 2, 1] + ] + + C = strassen(A, B, threshold=1) + print("A * B =") + for row in C: + print(row) + + # verify against naive + expected = naive_mul(A, B) + assert C == expected, "Strassen result differs from naive multiplication!" + print("Verified: result matches naive multiplication.") From 980a8b80e04a6a4b0141ed736b642ed521260704 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 19:29:51 +0530 Subject: [PATCH 2/8] feat: Strassen's matrix multiplication algorithm added --- ...sen_matrix_multiplication.py => strassen_matrix_multiply.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename matrix/{strassen_matrix_multiplication.py => strassen_matrix_multiply.py} (98%) diff --git a/matrix/strassen_matrix_multiplication.py b/matrix/strassen_matrix_multiply.py similarity index 98% rename from matrix/strassen_matrix_multiplication.py rename to matrix/strassen_matrix_multiply.py index 708a7bbce9dc..8c5978423976 100644 --- a/matrix/strassen_matrix_multiplication.py +++ b/matrix/strassen_matrix_multiply.py @@ -129,4 +129,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: # verify against naive expected = naive_mul(A, B) assert C == expected, "Strassen result differs from naive multiplication!" - print("Verified: result matches naive multiplication.") + print("Verified: result matches naive multiplication.") \ No newline at end of file From d4a95a678dc7fbffcd9241a78370558d9fb95f3a Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 19:41:55 +0530 Subject: [PATCH 3/8] feat: Strassen's matrix multiplication algorithm added --- matrix/strassen_matrix_multiply.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 8c5978423976..5f71db7cdeed 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -1,3 +1,18 @@ +""" +Strassen's Matrix Multiplication Algorithm +------------------------------------------ +An optimized divide-and-conquer algorithm for matrix multiplication that +reduces the number of multiplications from 8 (in the naive approach) +to 7 per recursion step. + +This results in a time complexity of approximately O(n^2.807), +which is faster than the standard O(n^3) algorithm for large matrices. + +Reference: +https://en.wikipedia.org/wiki/Strassen_algorithm +""" + + from typing import List Matrix = List[List[int]] From 7889f30ca9605131a909733ccfdbdb297f4e0442 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Oct 2025 14:14:07 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- matrix/strassen_matrix_multiply.py | 39 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 5f71db7cdeed..3ffed1bfb8df 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -12,22 +12,24 @@ https://en.wikipedia.org/wiki/Strassen_algorithm """ - from typing import List Matrix = List[List[int]] + def add(A: Matrix, B: Matrix) -> Matrix: n = len(A) return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)] + def sub(A: Matrix, B: Matrix) -> Matrix: n = len(A) return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)] + def naive_mul(A: Matrix, B: Matrix) -> Matrix: n = len(A) - C = [[0]*n for _ in range(n)] + C = [[0] * n for _ in range(n)] for i in range(n): ai = A[i] ci = C[i] @@ -38,23 +40,27 @@ def naive_mul(A: Matrix, B: Matrix) -> Matrix: ci[j] += a_ik * bk[j] return C + def next_power_of_two(n: int) -> int: p = 1 while p < n: p <<= 1 return p + def pad_matrix(A: Matrix, size: int) -> Matrix: n = len(A) - padded = [[0]*size for _ in range(size)] + padded = [[0] * size for _ in range(size)] for i in range(n): for j in range(len(A[0])): padded[i][j] = A[i][j] return padded + def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix: return [row[:cols] for row in A[:rows]] + def split(A: Matrix) -> tuple: n = len(A) mid = n // 2 @@ -64,10 +70,11 @@ def split(A: Matrix) -> tuple: A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)] return A11, A12, A21, A22 + def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix: n2 = len(C11) n = n2 * 2 - C = [[0]*n for _ in range(n)] + C = [[0] * n for _ in range(n)] for i in range(n2): for j in range(n2): C[i][j] = C11[i][j] @@ -76,19 +83,21 @@ def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix: C[i + n2][j + n2] = C22[i][j] return C + def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix: """ Multiply square matrices A and B using Strassen algorithm. threshold: below this size, uses naive multiplication (tweakable). """ - assert len(A) == len(A[0]) == len(B) == len(B[0]), "Only square matrices supported in this implementation" + assert len(A) == len(A[0]) == len(B) == len(B[0]), ( + "Only square matrices supported in this implementation" + ) n_orig = len(A) if n_orig == 0: return [] - m = next_power_of_two(n_orig) - if m != n_orig: + if (m := next_power_of_two(n_orig)) != n_orig: A_pad = pad_matrix(A, m) B_pad = pad_matrix(B, m) else: @@ -99,6 +108,7 @@ def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix: C = unpad_matrix(C_pad, n_orig, n_orig) return C + def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: n = len(A) if n <= threshold: @@ -124,17 +134,10 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: return join(C11, C12, C21, C22) + if __name__ == "__main__": - A = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ] - B = [ - [9, 8, 7], - [6, 5, 4], - [3, 2, 1] - ] + A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]] C = strassen(A, B, threshold=1) print("A * B =") @@ -144,4 +147,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: # verify against naive expected = naive_mul(A, B) assert C == expected, "Strassen result differs from naive multiplication!" - print("Verified: result matches naive multiplication.") \ No newline at end of file + print("Verified: result matches naive multiplication.") From 981f0c970604846f91a726be1e307e647b9b8030 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 19:52:34 +0530 Subject: [PATCH 5/8] feat: Strassen's matrix multiplication algorithm added --- matrix/strassen_matrix_multiply.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 5f71db7cdeed..a6958b07df81 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -13,9 +13,7 @@ """ -from typing import List - -Matrix = List[List[int]] +Matrix = list[list[int]] def add(A: Matrix, B: Matrix) -> Matrix: n = len(A) From 6880ee5c666e604c8597167ca78b13fc9c68ca9c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Oct 2025 14:25:29 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- matrix/strassen_matrix_multiply.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 412e19245af2..43d41ad6cf1a 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -14,6 +14,7 @@ Matrix = list[list[int]] + def add(A: Matrix, B: Matrix) -> Matrix: n = len(A) return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)] From 46ef515566026f764eaad4a2f5b4660ed852eecf Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 20:02:37 +0530 Subject: [PATCH 7/8] feat: Strassen's matrix multiplication algorithm added --- matrix/strassen_matrix_multiply.py | 196 ++++++++++++++++++----------- 1 file changed, 123 insertions(+), 73 deletions(-) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 412e19245af2..637d6d68001a 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -14,122 +14,173 @@ Matrix = list[list[int]] -def add(A: Matrix, B: Matrix) -> Matrix: - n = len(A) - return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)] +def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix: + """ + Add two square matrices of the same size. + + >>> add([[1,2],[3,4]], [[5,6],[7,8]]) + [[6, 8], [10, 12]] + """ + n = len(matrix_a) + return [[matrix_a[i][j] + matrix_b[i][j] for j in range(n)] for i in range(n)] + + +def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix: + """ + Subtract matrix_b from matrix_a. + + >>> sub([[5,6],[7,8]], [[1,2],[3,4]]) + [[4, 4], [4, 4]] + """ + n = len(matrix_a) + return [[matrix_a[i][j] - matrix_b[i][j] for j in range(n)] for i in range(n)] -def sub(A: Matrix, B: Matrix) -> Matrix: - n = len(A) - return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)] +def naive_mul(matrix_a: Matrix, matrix_b: Matrix) -> Matrix: + """ + Multiply two square matrices using the naive O(n^3) method. -def naive_mul(A: Matrix, B: Matrix) -> Matrix: - n = len(A) - C = [[0] * n for _ in range(n)] + >>> naive_mul([[1,2],[3,4]], [[5,6],[7,8]]) + [[19, 22], [43, 50]] + """ + n = len(matrix_a) + result = [[0] * n for _ in range(n)] for i in range(n): - ai = A[i] - ci = C[i] + row_a = matrix_a[i] + row_result = result[i] for k in range(n): - a_ik = ai[k] - bk = B[k] + a_ik = row_a[k] + col_b = matrix_b[k] for j in range(n): - ci[j] += a_ik * bk[j] - return C + row_result[j] += a_ik * col_b[j] + return result def next_power_of_two(n: int) -> int: - p = 1 - while p < n: - p <<= 1 - return p + """ + Return the next power of two greater than or equal to n. + + >>> next_power_of_two(5) + 8 + """ + power = 1 + while power < n: + power <<= 1 + return power + +def pad_matrix(matrix: Matrix, size: int) -> Matrix: + """ + Pad a matrix with zeros to reach the given size. -def pad_matrix(A: Matrix, size: int) -> Matrix: - n = len(A) + >>> pad_matrix([[1,2],[3,4]], 4) + [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] + """ + rows = len(matrix) + cols = len(matrix[0]) padded = [[0] * size for _ in range(size)] - for i in range(n): - for j in range(len(A[0])): - padded[i][j] = A[i][j] + for i in range(rows): + for j in range(cols): + padded[i][j] = matrix[i][j] return padded -def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix: - return [row[:cols] for row in A[:rows]] +def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix: + """ + Remove padding from a matrix. + + >>> unpad_matrix([[1,2,0],[3,4,0],[0,0,0]], 2, 2) + [[1, 2], [3, 4]] + """ + return [row[:cols] for row in matrix[:rows]] + +def split(matrix: Matrix) -> tuple: + """ + Split a matrix into four quadrants (top-left, top-right, bottom-left, bottom-right). -def split(A: Matrix) -> tuple: - n = len(A) + >>> split([[1,2],[3,4]]) + ([[1]], [[2]], [[3]], [[4]]) + """ + n = len(matrix) mid = n // 2 - A11 = [[A[i][j] for j in range(mid)] for i in range(mid)] - A12 = [[A[i][j] for j in range(mid, n)] for i in range(mid)] - A21 = [[A[i][j] for j in range(mid)] for i in range(mid, n)] - A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)] - return A11, A12, A21, A22 + top_left = [[matrix[i][j] for j in range(mid)] for i in range(mid)] + top_right = [[matrix[i][j] for j in range(mid, n)] for i in range(mid)] + bottom_left = [[matrix[i][j] for j in range(mid)] for i in range(mid, n)] + bottom_right = [[matrix[i][j] for j in range(mid, n)] for i in range(mid, n)] + return top_left, top_right, bottom_left, bottom_right -def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix: - n2 = len(C11) +def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix: + """ + Join four quadrants into a single matrix. + + >>> join([[1]], [[2]], [[3]], [[4]]) + [[1, 2], [3, 4]] + """ + n2 = len(c11) n = n2 * 2 - C = [[0] * n for _ in range(n)] + result = [[0] * n for _ in range(n)] for i in range(n2): for j in range(n2): - C[i][j] = C11[i][j] - C[i][j + n2] = C12[i][j] - C[i + n2][j] = C21[i][j] - C[i + n2][j + n2] = C22[i][j] - return C + result[i][j] = c11[i][j] + result[i][j + n2] = c12[i][j] + result[i + n2][j] = c21[i][j] + result[i + n2][j + n2] = c22[i][j] + return result -def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix: +def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix: """ - Multiply square matrices A and B using Strassen algorithm. - threshold: below this size, uses naive multiplication (tweakable). + Multiply two square matrices using Strassen's algorithm. + Uses naive multiplication for matrices smaller than threshold. + + >>> strassen([[1,2],[3,4]], [[5,6],[7,8]]) + [[19, 22], [43, 50]] """ - assert len(A) == len(A[0]) == len(B) == len(B[0]), ( - "Only square matrices supported in this implementation" + assert len(matrix_a) == len(matrix_a[0]) == len(matrix_b) == len(matrix_b[0]), ( + "Only square matrices supported" ) - n_orig = len(A) + n_orig = len(matrix_a) if n_orig == 0: return [] if (m := next_power_of_two(n_orig)) != n_orig: - A_pad = pad_matrix(A, m) - B_pad = pad_matrix(B, m) + a_pad = pad_matrix(matrix_a, m) + b_pad = pad_matrix(matrix_b, m) else: - A_pad, B_pad = A, B - - C_pad = _strassen_recursive(A_pad, B_pad, threshold) + a_pad, b_pad = matrix_a, matrix_b - C = unpad_matrix(C_pad, n_orig, n_orig) - return C + c_pad = _strassen_recursive(a_pad, b_pad, threshold) + return unpad_matrix(c_pad, n_orig, n_orig) -def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: - n = len(A) +def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix: + n = len(matrix_a) if n <= threshold: - return naive_mul(A, B) + return naive_mul(matrix_a, matrix_b) if n == 1: - return [[A[0][0] * B[0][0]]] + return [[matrix_a[0][0] * matrix_b[0][0]]] - A11, A12, A21, A22 = split(A) - B11, B12, B21, B22 = split(B) + a11, a12, a21, a22 = split(matrix_a) + b11, b12, b21, b22 = split(matrix_b) - M1 = _strassen_recursive(add(A11, A22), add(B11, B22), threshold) - M2 = _strassen_recursive(add(A21, A22), B11, threshold) - M3 = _strassen_recursive(A11, sub(B12, B22), threshold) - M4 = _strassen_recursive(A22, sub(B21, B11), threshold) - M5 = _strassen_recursive(add(A11, A12), B22, threshold) - M6 = _strassen_recursive(sub(A21, A11), add(B11, B12), threshold) - M7 = _strassen_recursive(sub(A12, A22), add(B21, B22), threshold) + m1 = _strassen_recursive(add(a11, a22), add(b11, b22), threshold) + m2 = _strassen_recursive(add(a21, a22), b11, threshold) + m3 = _strassen_recursive(a11, sub(b12, b22), threshold) + m4 = _strassen_recursive(a22, sub(b21, b11), threshold) + m5 = _strassen_recursive(add(a11, a12), b22, threshold) + m6 = _strassen_recursive(sub(a21, a11), add(b11, b12), threshold) + m7 = _strassen_recursive(sub(a12, a22), add(b21, b22), threshold) - C11 = add(sub(add(M1, M4), M5), M7) - C12 = add(M3, M5) - C21 = add(M2, M4) - C22 = add(sub(add(M1, M3), M2), M6) + c11 = add(sub(add(m1, m4), m5), m7) + c12 = add(m3, m5) + c21 = add(m2, m4) + c22 = add(sub(add(m1, m3), m2), m6) - return join(C11, C12, C21, C22) + return join(c11, c12, c21, c22) if __name__ == "__main__": @@ -141,7 +192,6 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: for row in C: print(row) - # verify against naive expected = naive_mul(A, B) assert C == expected, "Strassen result differs from naive multiplication!" print("Verified: result matches naive multiplication.") From 144024b3bc088e6c2695fe741ee58b56dc90e330 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 13 Oct 2025 20:15:51 +0530 Subject: [PATCH 8/8] feat: Strassen's matrix multiplication algorithm added --- matrix/strassen_matrix_multiply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matrix/strassen_matrix_multiply.py b/matrix/strassen_matrix_multiply.py index 637d6d68001a..62cd8924f3a8 100644 --- a/matrix/strassen_matrix_multiply.py +++ b/matrix/strassen_matrix_multiply.py @@ -193,5 +193,5 @@ def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> M print(row) expected = naive_mul(A, B) - assert C == expected, "Strassen result differs from naive multiplication!" + assert expected == C, "Strassen result differs from naive multiplication!" print("Verified: result matches naive multiplication.")