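Update: pass the nn.LayerNorm shape as a 1-tuple and replace the x[:, :-1] slice in forward with a negative F.pad.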
[mygpt.git] / mygpt.py
index 212e1a5..d6879dc 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -104,7 +104,7 @@ class MyGPT(nn.Module):
         for _ in range(nb_blocks):
             trunk_blocks += [
                 Residual(
-                    nn.LayerNorm(dim_model),
+                    nn.LayerNorm((dim_model,)),
                     QKVAttention(
                         dim_in = dim_model,
                         dim_qk = dim_keys,
@@ -114,7 +114,7 @@ class MyGPT(nn.Module):
                     ),
                 ),
                 Residual(
-                    nn.LayerNorm(dim_model),
+                    nn.LayerNorm((dim_model,)),
                     nn.Linear(in_features = dim_model, out_features = dim_hidden),
                     nn.ReLU(),
                     nn.Linear(in_features = dim_hidden, out_features = dim_model),
@@ -131,7 +131,8 @@ class MyGPT(nn.Module):
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x[:, :-1]
+        x = F.pad(x, (0, 0, 0, -1))
+        return x
 
 ######################################################################