- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattention.py
235 lines (180 loc) · 10.7 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
importtorch
fromtorchimportnn
fromtorch.nnimportfunctionalasF
importmath
"""
SelfAttention 自注意力机制
SelfAttention 类实现了自注意力机制,用于计算序列中每个元素与其他元素之间的关系,从而捕捉序列内部的依赖关系。
一 初始化参数: n_heads, d_embed, in_proj_bias=True, out_proj_bias=True
二 主要组件: self.in_proj, self.out_proj, self.n_heads, self.d_head
三 前向传播方法 (forward)
详细步骤:
1.输入投影:
将输入 x 通过线性层 self.in_proj 投影到 Q, K, V 三个向量,每个向量的维度为 d_embed。使用 chunk(3, dim=-1) 将投影后的向量分割成 Q, K, V。
2.调整形状以适应多头注意力:
将 Q, K, V 的形状从 (Batch_Size, Seq_Len, Dim) 调整为 (Batch_Size, Seq_Len, H, Dim/H),然后转置为 (Batch_Size, H, Seq_Len, Dim/H),以便进行多头注意力计算。
3.计算注意力权重:
计算 Q 和 K 的点积,得到注意力权重 weight,形状为 (Batch_Size, H, Seq_Len, Seq_Len)。
4.应用因果掩码(可选):
如果 causal_mask 为 True,则使用上三角掩码将上三角部分的注意力权重设为 -inf,以防止模型看到未来的信息。这在自回归模型(如 GPT)中常用。
5.缩放和 softmax:
对注意力权重进行缩放,除以 sqrt(d_head),以防止数值不稳定。应用 softmax 激活函数,将权重归一化。
6.计算最终输出:
将注意力权重与 V 相乘,得到加权后的值 output,形状为 (Batch_Size, H, Seq_Len, Dim/H)。调整形状为 (Batch_Size, Seq_Len, Dim)。通过线性层 self.out_proj 投影回原始的嵌入维度。
"""
classSelfAttention(nn.Module):
'''
n_heads(int): 注意力头的数量。将输入的嵌入维度分成多个头,每个头独立计算注意力。
d_embed(int): 输入嵌入的维度大小。
in_proj_bias(bool): 是否在输入投影层中添加偏置项。默认为 True。
out_proj_bias(bool): 是否在输出投影层中添加偏置项。默认为 True。
'''
def__init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# This combines the Wq, Wk and Wv matrices into one matrix
# 将输入的嵌入向量线性投影到查询(Query)、键(Key)和值(Value)三个向量。
# 线性层的输出维度为 3 * d_embed,因为它同时生成 Q、K、V 三个向量。
self.in_proj=nn.Linear(d_embed, 3*d_embed, bias=in_proj_bias)
# This one represents the Wo matrix
# 将注意力机制输出的结果线性投影回原始的嵌入维度 d_embed。
self.out_proj=nn.Linear(d_embed, d_embed, bias=out_proj_bias)
# 注意力头的数量
self.n_heads=n_heads
# 每个注意力头的维度大小,计算方式为 d_embed // n_heads
self.d_head=d_embed//n_heads
defforward(self, x, causal_mask=False):
# x: # (Batch_Size, Seq_Len, Dim)
# 输入 x 的形状: (Batch_Size, Seq_Len, Dim)
# (Batch_Size, Seq_Len, Dim)
input_shape=x.shape
# (Batch_Size, Seq_Len, Dim)
batch_size, sequence_length, d_embed=input_shape
# (Batch_Size, Seq_Len, H, Dim / H)
# 重新调整形状以适应多头注意力
interim_shape= (batch_size, sequence_length, self.n_heads, self.d_head)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensor of shape (Batch_Size, Seq_Len, Dim)
# 输入投影:将 x 投影到 Q, K, V
q, k, v=self.in_proj(x).chunk(3, dim=-1) # 每个的形状: (Batch_Size, Seq_Len, Dim)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
# 调整 Q, K, V 的形状以适应多头注意力
q=q.view(interim_shape).transpose(1, 2)
k=k.view(interim_shape).transpose(1, 2)
v=v.view(interim_shape).transpose(1, 2)
# (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
# 计算注意力权重:Q @ K^T
weight=q @ k.transpose(-1, -2)
# 如果使用因果掩码,则将上三角部分设为负无穷
ifcausal_mask:
# Mask where the upper triangle (above the principal diagonal) is 1
mask=torch.ones_like(weight, dtype=torch.bool).triu(1) # 上三角为 True
# Fill the upper triangle with -inf
# 上三角填充为 -inf
weight.masked_fill_(mask, -torch.inf)
# Divide by d_k (Dim / H).
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
# 缩放注意力权重
weight/=math.sqrt(self.d_head)
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
# 应用 softmax 激活函数
weight=F.softmax(weight, dim=-1)
# (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
# 计算注意力输出:weight @ V
output=weight @ v
# (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
# 调整输出形状
output=output.transpose(1, 2)
# (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
output=output.reshape(input_shape)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
# 线性投影输出
output=self.out_proj(output)
# (Batch_Size, Seq_Len, Dim)
returnoutput
"""
CrossAttention 交叉注意力机制
CrossAttention 类实现了交叉注意力机制,用于在两个不同的序列之间建立关联。
一 初始化参数: n_heads, d_embed, d_cross, in_proj_bias, out_proj_bias
二 主要组件: self.q_proj, self.k_proj, self.v_proj, self.out_proj, self.n_heads, self.d_head
三 前向传播方法 (forward)
详细步骤:
1.输入投影:
将查询序列 x 通过线性层 self.q_proj 投影到 Q,维度为 d_embed。
将上下文序列 y 通过线性层 self.k_proj 和 self.v_proj 分别投影到 K 和 V,维度均为 d_embed。
2.调整形状以适应多头注意力:
将 Q, K, V 的形状从 (Batch_Size, Seq_Len, Dim) 调整为 (Batch_Size, H, Seq_Len, Dim/H),以便进行多头注意力计算。
3.计算注意力权重:
计算 Q 和 K 的点积,得到注意力权重 weight,形状为 (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)。
4.缩放和 softmax:
对注意力权重进行缩放,除以 sqrt(d_head)。应用 softmax 激活函数,将权重归一化。
5.计算最终输出:
将注意力权重与 V 相乘,得到加权后的值 output,形状为 (Batch_Size, H, Seq_Len_Q, Dim_Q/H)。
调整形状为 (Batch_Size, Seq_Len_Q, Dim_Q)。通过线性层 self.out_proj 投影回原始的嵌入维度。
"""
classCrossAttention(nn.Module):
'''
n_heads(int): 注意力头的数量。
d_embed(int): 查询(Query)嵌入的维度大小。
d_cross(int): 键(Key)和值(Value)嵌入的维度大小。
in_proj_bias(bool): 是否在 Q, K, V 投影层中添加偏置项。默认为 True。
out_proj_bias(bool): 是否在输出投影层中添加偏置项。默认为 True。
'''
def__init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# 将查询序列 x 投影到查询向量 Q,维度为 d_embed
self.q_proj=nn.Linear(d_embed, d_embed, bias=in_proj_bias)
# 将上下文序列 y 投影到键向量 K,维度为 d_embed
self.k_proj=nn.Linear(d_cross, d_embed, bias=in_proj_bias)
# 将上下文序列 y 投影到值向量 V,维度为 d_embed
self.v_proj=nn.Linear(d_cross, d_embed, bias=in_proj_bias)
# 将注意力机制的输出线性投影回原始的嵌入维度 d_embed
self.out_proj=nn.Linear(d_embed, d_embed, bias=out_proj_bias)
# 注意力头的数量
self.n_heads=n_heads
# 每个注意力头的维度大小,计算方式为 d_embed // n_heads
self.d_head=d_embed//n_heads
defforward(self, x, y):
# x (latent): # (Batch_Size, Seq_Len_Q, Dim_Q)
# y (context): # (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
# 查询 x 的形状: (Batch_Size, Seq_Len_Q, Dim_Q)
# 上下文 y 的形状: (Batch_Size, Seq_Len_KV, Dim_KV)
input_shape=x.shape
batch_size, sequence_length, d_embed=input_shape
# Divide each embedding of Q into multiple heads such that d_heads * n_heads = Dim_Q
# 调整形状以适应多头注意力
interim_shape= (batch_size, -1, self.n_heads, self.d_head)
# 投影 Q, K, V
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
q=self.q_proj(x)
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
k=self.k_proj(y)
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
v=self.v_proj(y)
# 调整 Q, K, V 的形状以适应多头注意力
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
q=q.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
k=k.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
v=v.view(interim_shape).transpose(1, 2)
# 计算注意力权重:Q @ K^T
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight=q @ k.transpose(-1, -2)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
# 缩放注意力权重
weight/=math.sqrt(self.d_head)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
# 应用 softmax 激活函数
weight=F.softmax(weight, dim=-1)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
# 计算注意力输出:weight @ V
output=weight @ v
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
# 调整输出形状
output=output.transpose(1, 2).contiguous()
# (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
output=output.view(input_shape)
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
# 线性投影输出
output=self.out_proj(output)
# (Batch_Size, Seq_Len_Q, Dim_Q)
returnoutput