Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index ec50722..4d5077a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -87,6 +87,8 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--proportion_memex", type=float, default=0)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -101,9 +103,9 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
 
-parser.add_argument("--gate_dropout_sync", type=str2bool, default=True)
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
 
-parser.add_argument("--gate_dropout_replace", type=str2bool, default=True)
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
 
 parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
@@ -736,6 +738,9 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.proportion_memex > 0:
+    vocabulary_size += 1
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -897,7 +902,24 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
+    def add_memex(batches, proportion_memex):
+        for input in batches:
+            if torch.rand(1).item() < proportion_memex:
+                yield torch.cat(
+                    [
+                        input,
+                        torch.full(
+                            (input.size(0), 1), vocabulary_size - 1, device=input.device
+                        ),
+                        input,
+                    ],
+                    dim=1,
+                )
+            yield input
+
+    train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+
+    for input in train_batches:
         model.reset_inner_loss()
         input = input.to(device)