attend.py
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
# constants

AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    # wrap fn so it only ever executes on the first call
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)
# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
        # compute capability 8.0 corresponds to A100, so use >= to include it
        if device_version >= version.parse('8.0'):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)
    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            # rescale q so the custom scale overrides sdpa's default of dim_head ** -0.5
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out
    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
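
# --- usage sketch (illustrative only, not part of the original file) ---
# Attend expects q, k, v already split into heads: (batch, heads, seq_len, dim_head).
# The shapes and hyperparameters below are arbitrary example values.

if __name__ == '__main__':
    attend = Attend(dropout = 0.1, flash = False, scale = None)

    b, h, n, d = 2, 8, 64, 32              # batch, heads, sequence length, dim per head
    q = torch.randn(b, h, n, d)
    k = torch.randn(b, h, n, d)
    v = torch.randn(b, h, n, d)

    out = attend(q, k, v)                  # attention output, same shape as q
    assert out.shape == (b, h, n, d)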