1212https://en.wikipedia.org/wiki/Strassen_algorithm
1313"""
1414
15-
1615from typing import List
1716
1817Matrix = List [List [int ]]
1918
19+
2020def add (A : Matrix , B : Matrix ) -> Matrix :
2121 n = len (A )
2222 return [[A [i ][j ] + B [i ][j ] for j in range (n )] for i in range (n )]
2323
24+
2425def sub (A : Matrix , B : Matrix ) -> Matrix :
2526 n = len (A )
2627 return [[A [i ][j ] - B [i ][j ] for j in range (n )] for i in range (n )]
2728
29+
2830def naive_mul (A : Matrix , B : Matrix ) -> Matrix :
2931 n = len (A )
30- C = [[0 ]* n for _ in range (n )]
32+ C = [[0 ] * n for _ in range (n )]
3133 for i in range (n ):
3234 ai = A [i ]
3335 ci = C [i ]
@@ -38,23 +40,27 @@ def naive_mul(A: Matrix, B: Matrix) -> Matrix:
3840 ci [j ] += a_ik * bk [j ]
3941 return C
4042
43+
4144def next_power_of_two (n : int ) -> int :
4245 p = 1
4346 while p < n :
4447 p <<= 1
4548 return p
4649
50+
4751def pad_matrix (A : Matrix , size : int ) -> Matrix :
4852 n = len (A )
49- padded = [[0 ]* size for _ in range (size )]
53+ padded = [[0 ] * size for _ in range (size )]
5054 for i in range (n ):
5155 for j in range (len (A [0 ])):
5256 padded [i ][j ] = A [i ][j ]
5357 return padded
5458
59+
5560def unpad_matrix (A : Matrix , rows : int , cols : int ) -> Matrix :
5661 return [row [:cols ] for row in A [:rows ]]
5762
63+
5864def split (A : Matrix ) -> tuple :
5965 n = len (A )
6066 mid = n // 2
@@ -64,10 +70,11 @@ def split(A: Matrix) -> tuple:
6470 A22 = [[A [i ][j ] for j in range (mid , n )] for i in range (mid , n )]
6571 return A11 , A12 , A21 , A22
6672
73+
6774def join (C11 : Matrix , C12 : Matrix , C21 : Matrix , C22 : Matrix ) -> Matrix :
6875 n2 = len (C11 )
6976 n = n2 * 2
70- C = [[0 ]* n for _ in range (n )]
77+ C = [[0 ] * n for _ in range (n )]
7178 for i in range (n2 ):
7279 for j in range (n2 ):
7380 C [i ][j ] = C11 [i ][j ]
@@ -76,19 +83,21 @@ def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
7683 C [i + n2 ][j + n2 ] = C22 [i ][j ]
7784 return C
7885
86+
7987def strassen (A : Matrix , B : Matrix , threshold : int = 64 ) -> Matrix :
8088 """
8189 Multiply square matrices A and B using Strassen algorithm.
8290 threshold: below this size, uses naive multiplication (tweakable).
8391 """
84- assert len (A ) == len (A [0 ]) == len (B ) == len (B [0 ]), "Only square matrices supported in this implementation"
92+ assert len (A ) == len (A [0 ]) == len (B ) == len (B [0 ]), (
93+ "Only square matrices supported in this implementation"
94+ )
8595
8696 n_orig = len (A )
8797 if n_orig == 0 :
8898 return []
8999
90- m = next_power_of_two (n_orig )
91- if m != n_orig :
100+ if (m := next_power_of_two (n_orig )) != n_orig :
92101 A_pad = pad_matrix (A , m )
93102 B_pad = pad_matrix (B , m )
94103 else :
@@ -99,6 +108,7 @@ def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
99108 C = unpad_matrix (C_pad , n_orig , n_orig )
100109 return C
101110
111+
102112def _strassen_recursive (A : Matrix , B : Matrix , threshold : int ) -> Matrix :
103113 n = len (A )
104114 if n <= threshold :
@@ -124,17 +134,10 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
124134
125135 return join (C11 , C12 , C21 , C22 )
126136
137+
127138if __name__ == "__main__" :
128- A = [
129- [1 , 2 , 3 ],
130- [4 , 5 , 6 ],
131- [7 , 8 , 9 ]
132- ]
133- B = [
134- [9 , 8 , 7 ],
135- [6 , 5 , 4 ],
136- [3 , 2 , 1 ]
137- ]
139+ A = [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]
140+ B = [[9 , 8 , 7 ], [6 , 5 , 4 ], [3 , 2 , 1 ]]
138141
139142 C = strassen (A , B , threshold = 1 )
140143 print ("A * B =" )
@@ -144,4 +147,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
144147 # verify against naive
145148 expected = naive_mul (A , B )
146149 assert C == expected , "Strassen result differs from naive multiplication!"
147- print ("Verified: result matches naive multiplication." )
150+ print ("Verified: result matches naive multiplication." )
0 commit comments