mygpt.git / commitdiff (from parent 02a6cbf)
Initialize w_o properly.
author    Francois Fleuret <francois@fleuret.org>  Mon, 25 Jul 2022 16:14:57 +0000 (18:14 +0200)
committer Francois Fleuret <francois@fleuret.org>  Mon, 25 Jul 2022 16:14:57 +0000 (18:14 +0200)
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index 42960b1..57cbbc6 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -51,7 +51,7 @@ class QKVAttention(nn.Module):
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(nb_heads, dim_in, dim_v)
+        self.w_o = randw(dim_in, dim_v * nb_heads)

         self.causal = causal
         self.attention_dropout = attention_dropout
@@ -67,8 +67,8 @@ class QKVAttention(nn.Module):
             a = a.masked_fill(mask, float('-inf'))

         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+        y = torch.einsum('nhts,nhsd->nthd', a, v)
+        y = y.flatten(2) @ self.w_o

         return y
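
Why this counts as initializing w_o "properly": assuming randw follows the 1/sqrt(last dimension) convention used elsewhere in mygpt.py, moving nb_heads out of the trailing position and flattening the heads makes that scale match the actual fan-in of the output projection. The sketch below is a minimal illustration under that assumption; the sizes and the attention/value tensors are placeholders, not values from the commit.

import math

import torch
from torch import nn


def randw(*d):
    # Assumed convention from mygpt.py: Gaussian entries with std 1/sqrt(d[-1]).
    return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))


# Illustrative sizes (not from the commit). Note dim_v * nb_heads == dim_in,
# the usual transformer setting, which makes y.flatten(2) @ w_o well-formed.
nb_heads, dim_in, dim_v = 4, 64, 16

# Old shape: the trailing dimension is dim_v, so the std is 1/sqrt(dim_v)
# even though the projection actually sums over nb_heads * dim_v inputs.
w_o_old = randw(nb_heads, dim_in, dim_v)
print(w_o_old.std().item())  # ~ 1/sqrt(16) = 0.25

# New shape: the trailing dimension is dim_v * nb_heads, the true fan-in
# of the flattened projection, giving std 1/sqrt(dim_v * nb_heads).
w_o_new = randw(dim_in, dim_v * nb_heads)
print(w_o_new.std().item())  # ~ 1/sqrt(64) = 0.125

# New forward path: heads are moved next to the feature axis ('nthd'),
# flattened, then sent through a single matrix multiply.
N, T = 2, 5
a = torch.softmax(torch.randn(N, nb_heads, T, T), dim=3)  # attention weights
v = torch.randn(N, nb_heads, T, dim_v)                    # per-head values
y = torch.einsum('nhts,nhsd->nthd', a, v)  # (N, T, nb_heads, dim_v)
y = y.flatten(2) @ w_o_new                 # (N, T, 64), since dim_v * nb_heads == dim_in
print(y.shape)  # torch.Size([2, 5, 64])

Storing w_o as one flattened matrix also lets the second einsum of the old forward path become a plain matmul, which is simpler and typically at least as fast.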