Skip to content

Commit 7889f30

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d4a95a6 commit 7889f30

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,24 @@
1212
https://en.wikipedia.org/wiki/Strassen_algorithm
1313
"""
1414

15-
1615
from typing import List
1716

1817
Matrix = List[List[int]]
1918

19+
2020
def add(A: Matrix, B: Matrix) -> Matrix:
2121
n = len(A)
2222
return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]
2323

24+
2425
def sub(A: Matrix, B: Matrix) -> Matrix:
2526
n = len(A)
2627
return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]
2728

29+
2830
def naive_mul(A: Matrix, B: Matrix) -> Matrix:
2931
n = len(A)
30-
C = [[0]*n for _ in range(n)]
32+
C = [[0] * n for _ in range(n)]
3133
for i in range(n):
3234
ai = A[i]
3335
ci = C[i]
@@ -38,23 +40,27 @@ def naive_mul(A: Matrix, B: Matrix) -> Matrix:
3840
ci[j] += a_ik * bk[j]
3941
return C
4042

43+
4144
def next_power_of_two(n: int) -> int:
4245
p = 1
4346
while p < n:
4447
p <<= 1
4548
return p
4649

50+
4751
def pad_matrix(A: Matrix, size: int) -> Matrix:
4852
n = len(A)
49-
padded = [[0]*size for _ in range(size)]
53+
padded = [[0] * size for _ in range(size)]
5054
for i in range(n):
5155
for j in range(len(A[0])):
5256
padded[i][j] = A[i][j]
5357
return padded
5458

59+
5560
def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix:
5661
return [row[:cols] for row in A[:rows]]
5762

63+
5864
def split(A: Matrix) -> tuple:
5965
n = len(A)
6066
mid = n // 2
@@ -64,10 +70,11 @@ def split(A: Matrix) -> tuple:
6470
A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)]
6571
return A11, A12, A21, A22
6672

73+
6774
def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
6875
n2 = len(C11)
6976
n = n2 * 2
70-
C = [[0]*n for _ in range(n)]
77+
C = [[0] * n for _ in range(n)]
7178
for i in range(n2):
7279
for j in range(n2):
7380
C[i][j] = C11[i][j]
@@ -76,19 +83,21 @@ def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
7683
C[i + n2][j + n2] = C22[i][j]
7784
return C
7885

86+
7987
def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
8088
"""
8189
Multiply square matrices A and B using Strassen algorithm.
8290
threshold: below this size, uses naive multiplication (tweakable).
8391
"""
84-
assert len(A) == len(A[0]) == len(B) == len(B[0]), "Only square matrices supported in this implementation"
92+
assert len(A) == len(A[0]) == len(B) == len(B[0]), (
93+
"Only square matrices supported in this implementation"
94+
)
8595

8696
n_orig = len(A)
8797
if n_orig == 0:
8898
return []
8999

90-
m = next_power_of_two(n_orig)
91-
if m != n_orig:
100+
if (m := next_power_of_two(n_orig)) != n_orig:
92101
A_pad = pad_matrix(A, m)
93102
B_pad = pad_matrix(B, m)
94103
else:
@@ -99,6 +108,7 @@ def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
99108
C = unpad_matrix(C_pad, n_orig, n_orig)
100109
return C
101110

111+
102112
def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
103113
n = len(A)
104114
if n <= threshold:
@@ -124,17 +134,10 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
124134

125135
return join(C11, C12, C21, C22)
126136

137+
127138
if __name__ == "__main__":
128-
A = [
129-
[1, 2, 3],
130-
[4, 5, 6],
131-
[7, 8, 9]
132-
]
133-
B = [
134-
[9, 8, 7],
135-
[6, 5, 4],
136-
[3, 2, 1]
137-
]
139+
A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
140+
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
138141

139142
C = strassen(A, B, threshold=1)
140143
print("A * B =")
@@ -144,4 +147,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
144147
# verify against naive
145148
expected = naive_mul(A, B)
146149
assert C == expected, "Strassen result differs from naive multiplication!"
147-
print("Verified: result matches naive multiplication.")
150+
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)