Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 8 Feb 2024 06:24:13 +0000 (07:24 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 8 Feb 2024 06:24:13 +0000 (07:24 +0100)
main.py

diff --git a/main.py b/main.py
index 4d5077a..91c885b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -87,7 +87,7 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
-parser.add_argument("--proportion_memex", type=float, default=0)
+parser.add_argument("--proba_memex", type=float, default=0)
 
 parser.add_argument("--dim_model", type=int, default=None)
 
@@ -738,7 +738,7 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
-if args.proportion_memex > 0:
+if args.proba_memex > 0:
     vocabulary_size += 1
 
 log_string(f"vocabulary_size {vocabulary_size}")
@@ -902,22 +902,26 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    def add_memex(batches, proportion_memex):
+    def add_memex(batches, proba_memex):
         for input in batches:
-            if torch.rand(1).item() < proportion_memex:
+            if torch.rand(1).item() < proba_memex:
+                sep = (
+                    torch.full(
+                        (input.size(0), 1), vocabulary_size - 1, device=input.device
+                    ),
+                )
+
                 yield torch.cat(
                     [
                         input,
-                        torch.full(
-                            (input.size(0), 1), vocabulary_size - 1, device=input.device
-                        ),
+                        sep,
                         input,
                     ],
                     dim=1,
                 )
             yield input
 
-    train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+    train_batches = add_memex(task.batches(split="train"), args.proba_memex)
 
     for input in train_batches:
         model.reset_inner_loss()