projects
/
mygpt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
a1a7cb9
)
Fixed the size of w_o.
author
Francois Fleuret
<francois@fleuret.org>
Tue, 26 Jul 2022 15:06:13 +0000
(17:06 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Tue, 26 Jul 2022 15:06:13 +0000
(17:06 +0200)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
121ad80
..
ab16e1e
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@ -57,7 +57,7 @@ class QKVAttention(nn.Module):
class QKVAttention(nn.Module):
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_in, dim_v * nb_heads)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
def forward(self, x_q, x_kv = None):
if x_kv is None: x_kv = x_q
def forward(self, x_q, x_kv = None):
if x_kv is None: x_kv = x_q
@@ -142,7 +142,7 @@ if __name__ == '__main__':
if __name__ == '__main__':
model = MyGPT(
vocabulary_size = vocabulary_size,
model = MyGPT(
vocabulary_size = vocabulary_size,
-        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        dim_model = 18, dim_keys = 50, dim_hidden = 100,
nb_heads = 2, nb_blocks = 3,
dropout = 0.1
)
nb_heads = 2, nb_blocks = 3,
dropout = 0.1
)