X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=ab16e1e184bdc513f5cd189112cf78908d76abd0;hb=12184f604b37f36f07d7dcdd567b1c78f02c74db;hp=5370ffa5d57a876a8d51106bcb0c2e33bf5b3c28;hpb=84748a01e6d3c26037412592ce147b7753ce6117;p=mygpt.git

diff --git a/mygpt.py b/mygpt.py
index 5370ffa..ab16e1e 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -57,7 +57,7 @@ class QKVAttention(nn.Module):
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_in, dim_v * nb_heads)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
@@ -126,7 +126,7 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
-        x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
+        x = F.pad(x, (1, 0))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
@@ -142,7 +142,7 @@ if __name__ == '__main__':
     model = MyGPT(
         vocabulary_size = vocabulary_size,
-        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        dim_model = 18, dim_keys = 50, dim_hidden = 100,
         nb_heads = 2, nb_blocks = 3, dropout = 0.1
     )
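
The second hunk replaces an explicit concatenation of a zero column (which shifts the input right and inserts a zero "start" token) with a padding call. The sketch below is not part of mygpt.py; it is a minimal check, under the assumption that x is a 2D LongTensor of token indices of shape (B, T), that the two expressions produce the same tensor. B and T are illustrative sizes.

# Sketch only: verify that F.pad(x, (1, 0)) -- pad the last dimension with one
# zero on the left -- matches the removed torch.cat construction.
import torch
import torch.nn.functional as F

B, T = 2, 5                                           # hypothetical batch size and sequence length
x = torch.randint(0, 10, (B, T))                      # fake token indices

old = torch.cat((x.new_zeros(x.size(0), 1), x), 1)    # removed expression
new = F.pad(x, (1, 0))                                # added expression

assert torch.equal(old, new)                          # same values and dtype, shape (B, T + 1)
print(new.shape)                                      # torch.Size([2, 6])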