OCDC
[mygpt.git] / mygpt.py
index ab16e1e..43711b3 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -107,7 +107,8 @@ class MyGPT(nn.Module):
                     nn.LayerNorm(dim_model),
                     QKVAttention(
                         dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        dim_qk = dim_keys,
+                        dim_v = dim_model // nb_heads,
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),