projects
/
mygpt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
b4593ce
)
Added the (missing) W_o
author
Francois Fleuret
<francois@fleuret.org>
Mon, 25 Jul 2022 13:31:18 +0000
(15:31 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Mon, 25 Jul 2022 13:31:18 +0000
(15:31 +0200)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
4951460
..
42960b1
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-51,6
+51,7
@@
class QKVAttention(nn.Module):
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
+ self.w_o = randw(nb_heads, dim_in, dim_v)
self.causal = causal
self.attention_dropout = attention_dropout
self.causal = causal
self.attention_dropout = attention_dropout
@@
-61,12
+62,15
@@
class QKVAttention(nn.Module):
v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
if self.causal:
v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
if self.causal:
- mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
+ mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
+ < torch.arange(a.size(3), device = q.device)[None, None, None, :]
a = a.masked_fill(mask, float('-inf'))
a = a.softmax(dim = 3)
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum('nhts,nhsd->nhtd', a, v)
a = a.masked_fill(mask, float('-inf'))
a = a.softmax(dim = 3)
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum('nhts,nhsd->nhtd', a, v)
- return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+ y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+
+ return y
##############################
##############################
@@
-122,6
+126,8
@@
class MyGPT(nn.Module):
######################################################################
if __name__ == '__main__':
######################################################################
if __name__ == '__main__':
+ print('Basic check.')
+
vocabulary_size = 10
x = torch.randint(vocabulary_size, (25, 100))
vocabulary_size = 10
x = torch.randint(vocabulary_size, (25, 100))