- Notifications
You must be signed in to change notification settings - Fork 46.7k
/
Copy pathstrassen_matrix_multiplication.py
172 lines (139 loc) · 5.93 KB
/
strassen_matrix_multiplication.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from __future__ importannotations
importmath
def default_matrix_multiplication(a: list, b: list) -> list:
    """
    Multiplication only for 2x2 matrices

    Computes ``a @ b`` with the ordinary row-by-column formula. It serves as the
    base case of the Strassen recursion.

    :param a: left 2x2 operand, as a list of rows
    :param b: right 2x2 operand, as a list of rows
    :return: a new 2x2 list of rows holding the product
    :raises Exception: if ``a`` or ``b`` is not 2x2 (only the row count and the
        length of the first row are checked)
    """
    is_two_by_two = (
        len(a) == 2 and len(a[0]) == 2 and len(b) == 2 and len(b[0]) == 2
    )
    if not is_two_by_two:
        raise Exception("Matrices are not 2x2")

    a11, a12 = a[0][0], a[0][1]
    a21, a22 = a[1][0], a[1][1]
    b11, b12 = b[0][0], b[0][1]
    b21, b22 = b[1][0], b[1][1]

    first_row = [a11 * b11 + a12 * b21, a11 * b12 + a12 * b22]
    second_row = [a21 * b11 + a22 * b21, a21 * b12 + a22 * b22]
    return [first_row, second_row]
def matrix_addition(matrix_a: list, matrix_b: list):
    """
    Return the element-wise sum ``matrix_a + matrix_b`` as a new matrix.

    The shape is taken from ``matrix_a``; each cell is paired with the cell at
    the same row/column index in ``matrix_b``.
    """
    return [
        [value + matrix_b[r][c] for c, value in enumerate(row)]
        for r, row in enumerate(matrix_a)
    ]
def matrix_subtraction(matrix_a: list, matrix_b: list):
    """
    Return the element-wise difference ``matrix_a - matrix_b`` as a new matrix.

    The shape is taken from ``matrix_a``; each cell has the cell at the same
    row/column index in ``matrix_b`` subtracted from it.
    """
    return [
        [value - matrix_b[r][c] for c, value in enumerate(row)]
        for r, row in enumerate(matrix_a)
    ]
def split_matrix(a: list) -> tuple[list, list, list, list]:
    """
    Given a matrix with an even number of rows and columns, return its
    top_left, top_right, bot_left and bot_right quadrants.

    Rows and columns are halved independently, so non-square even matrices
    (e.g. 2x4) are split correctly too. The column count is taken from the
    first row. Each quadrant is a new list of new row lists.

    :raises Exception: if the row count or the column count is odd

    >>> split_matrix([[4,3,2,4],[2,3,1,1],[6,5,4,3],[8,4,1,6]])
    ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]])
    >>> split_matrix([
    ...     [4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6],
    ...     [4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6]
    ... ])  # doctest: +NORMALIZE_WHITESPACE
    ([[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4],
     [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1],
     [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3],
     [8, 4, 1, 6]])
    >>> split_matrix([[1, 2, 3, 4], [5, 6, 7, 8]])
    ([[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]])
    """
    if len(a) % 2 != 0 or len(a[0]) % 2 != 0:
        raise Exception("Odd matrices are not supported!")

    # Halve rows and columns separately; the previous version reused the row
    # count for the columns, which mis-split non-square matrices.
    mid_row = len(a) // 2
    mid_col = len(a[0]) // 2

    top, bottom = a[:mid_row], a[mid_row:]
    top_left = [row[:mid_col] for row in top]
    top_right = [row[mid_col:] for row in top]
    bot_left = [row[:mid_col] for row in bottom]
    bot_right = [row[mid_col:] for row in bottom]
    return top_left, top_right, bot_left, bot_right
def matrix_dimensions(matrix: list) -> tuple[int, int]:
    """Return ``(rows, columns)`` of *matrix*; the column count is read from the first row."""
    row_count = len(matrix)
    column_count = len(matrix[0])
    return row_count, column_count
def print_matrix(matrix: list) -> None:
    """Print *matrix* to stdout, one row (shown as its ``str`` form) per line."""
    rendered = "\n".join(map(str, matrix))
    print(rendered)
def actual_strassen(matrix_a: list, matrix_b: list) -> list:
    """
    Recursive function to calculate the product of two matrices, using the Strassen
    Algorithm. It only supports square matrices of any size that is a power of 2.

    Both operands are split into quadrants
    ``A = [[a, b], [c, d]]`` and ``B = [[e, f], [g, h]]``. Seven recursive
    products (instead of the naive eight) are then combined into the four
    quadrants of the result.

    :param matrix_a: left n x n operand, n a power of 2 (1 included)
    :param matrix_b: right operand of the same size
    :return: the n x n product as a new list of rows
    """
    dimensions = matrix_dimensions(matrix_a)
    # A 1x1 matrix (n = 2**0) cannot be split; multiply it directly instead of
    # letting split_matrix reject it as odd-sized.
    if dimensions == (1, 1):
        return [[matrix_a[0][0] * matrix_b[0][0]]]
    if dimensions == (2, 2):
        return default_matrix_multiplication(matrix_a, matrix_b)

    a, b, c, d = split_matrix(matrix_a)
    e, f, g, h = split_matrix(matrix_b)

    # Strassen's seven sub-products.
    t1 = actual_strassen(a, matrix_subtraction(f, h))  # a(f - h)
    t2 = actual_strassen(matrix_addition(a, b), h)  # (a + b)h
    t3 = actual_strassen(matrix_addition(c, d), e)  # (c + d)e
    t4 = actual_strassen(d, matrix_subtraction(g, e))  # d(g - e)
    t5 = actual_strassen(matrix_addition(a, d), matrix_addition(e, h))  # (a + d)(e + h)
    t6 = actual_strassen(matrix_subtraction(b, d), matrix_addition(g, h))  # (b - d)(g + h)
    t7 = actual_strassen(matrix_subtraction(a, c), matrix_addition(e, f))  # (a - c)(e + f)

    # Quadrants of the result: t5 + t4 - t2 + t6, t1 + t2, t3 + t4, t1 + t5 - t3 - t7.
    top_left = matrix_addition(matrix_subtraction(matrix_addition(t5, t4), t2), t6)
    top_right = matrix_addition(t1, t2)
    bot_left = matrix_addition(t3, t4)
    bot_right = matrix_subtraction(
        matrix_subtraction(matrix_addition(t1, t5), t3), t7
    )

    # Stitch the four quadrants back together row by row.
    upper = [left + right for left, right in zip(top_left, top_right)]
    lower = [left + right for left, right in zip(bot_left, bot_right)]
    return upper + lower
def _zero_pad(matrix: list, size: int) -> list:
    """Return a copy of *matrix* extended with zeros to a ``size`` x ``size`` square."""
    padded = [list(row) + [0] * (size - len(row)) for row in matrix]
    padded.extend([0] * size for _ in range(size - len(matrix)))
    return padded


def strassen(matrix1: list, matrix2: list) -> list:
    """
    Multiply two matrices of any compatible shape with Strassen's algorithm.

    Copies of both operands are zero-padded to a square whose side is a power
    of 2 (and at least 2), multiplied with ``actual_strassen``, and the padding
    is trimmed from the product. The caller's matrices are not modified.

    :param matrix1: an m x n matrix, as a list of rows
    :param matrix2: an n x p matrix, as a list of rows
    :return: the m x p product as a new list of rows
    :raises Exception: if the column count of ``matrix1`` differs from the
        row count of ``matrix2``

    >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
    [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
    >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
    [[139, 163], [121, 134], [100, 121]]
    >>> strassen([[1, 2], [3, 4]], [[5, 6], [7, 8]])
    [[19, 22], [43, 50]]
    >>> strassen([[2]], [[3]])
    [[6]]
    """
    dimension1 = matrix_dimensions(matrix1)
    dimension2 = matrix_dimensions(matrix2)
    if dimension1[1] != dimension2[0]:
        msg = (
            "Unable to multiply these matrices, please check the dimensions.\n"
            f"Matrix A: {matrix1}\n"
            f"Matrix B: {matrix2}"
        )
        raise Exception(msg)

    # Side of the padded square: the smallest power of 2 covering every
    # dimension, never below 2 so the recursion never has to split a 1x1.
    # bit_length avoids the float rounding of math.log2 / math.pow.
    maximum = max(*dimension1, *dimension2)
    size = max(2, 1 << (maximum - 1).bit_length())

    # Pad copies: the old version appended zeros to the caller's lists, and it
    # skipped the multiplication entirely when both inputs were square.
    product = actual_strassen(_zero_pad(matrix1, size), _zero_pad(matrix2, size))

    # Drop the padding: keep the first m rows and the first p columns.
    rows, cols = dimension1[0], dimension2[1]
    return [row[:cols] for row in product[:rows]]
if __name__ == "__main__":
    # Demo: multiply a 10x4 matrix by a 4x4 matrix and print the 10x4 product.
    matrix1 = [
        [2, 3, 4, 5],
        [6, 4, 3, 1],
        [2, 3, 6, 7],
        [3, 1, 2, 4],
        [2, 3, 4, 5],
        [6, 4, 3, 1],
        [2, 3, 6, 7],
        [3, 1, 2, 4],
        [2, 3, 4, 5],
        [6, 2, 3, 1],
    ]
    matrix2 = [
        [0, 2, 1, 1],
        [16, 2, 3, 3],
        [2, 2, 7, 7],
        [13, 11, 22, 4],
    ]
    product = strassen(matrix1, matrix2)
    print(product)