The "mask" array actually specifies what attention to discard.
author     Francois Fleuret <francois@fleuret.org>  Sat, 27 Aug 2022 09:28:03 +0000 (11:28 +0200)
committer  Francois Fleuret <francois@fleuret.org>  Sat, 27 Aug 2022 09:28:03 +0000 (11:28 +0200)
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index ebc9a83..f954797 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -66,9 +66,9 @@ class QKVAttention(nn.Module):
         a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
 
         if self.causal:
-            mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
-                   < torch.arange(a.size(3), device = q.device)[None, None, None, :]
-            a = a.masked_fill(mask, float('-inf'))
+            forbidden_attention = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
+                                  < torch.arange(a.size(3), device = q.device)[None, None, None, :]
+            a = a.masked_fill(forbidden_attention, float('-inf'))
 
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
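The renamed array is the "future" pattern of causal attention: entry [t, s] is True exactly when key position s comes after query position t, i.e. the attention the model must discard, which is why "mask" was a misleading name. Below is a minimal standalone sketch (not part of the commit; the sequence length T is a hypothetical value chosen for illustration) of the same broadcasting-and-masked_fill trick on a 2D tensor:

import torch

T = 4  # hypothetical sequence length for the illustration

# True where the key position s (column) lies in the future of the
# query position t (row), i.e. the attention to discard.
forbidden_attention = torch.arange(T)[:, None] < torch.arange(T)[None, :]
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

a = torch.zeros(T, T)                                  # dummy attention logits
a = a.masked_fill(forbidden_attention, float('-inf'))  # forbidden logits -> -inf
print(a.softmax(dim=1))                                # row t only attends to s <= t

After softmax, each row t spreads its probability mass only over positions s <= t, which is the causal constraint the forbidden_attention array encodes.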