From a1ae050705970007f965d2586c53e9bd262e46aa Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 8 Feb 2024 13:28:27 +0100 Subject: [PATCH] Update. --- main.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 91c885b..d6845e8 100755 --- a/main.py +++ b/main.py @@ -87,7 +87,9 @@ parser.add_argument("--model", type=str, default=None) parser.add_argument("--attention", type=str, default=None) -parser.add_argument("--proba_memex", type=float, default=0) +parser.add_argument("--memex_proba", type=float, default=0) + +parser.add_argument("--memex_nb_epochs", type=float, default=1) parser.add_argument("--dim_model", type=int, default=None) @@ -738,7 +740,7 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() -if args.proba_memex > 0: +if args.memex_proba > 0: vocabulary_size += 1 log_string(f"vocabulary_size {vocabulary_size}") @@ -902,13 +904,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - def add_memex(batches, proba_memex): + def add_memex(batches, memex_proba): for input in batches: - if torch.rand(1).item() < proba_memex: - sep = ( - torch.full( - (input.size(0), 1), vocabulary_size - 1, device=input.device - ), + if torch.rand(1).item() < memex_proba: + sep = torch.full( + (input.size(0), 1), vocabulary_size - 1, device=input.device ) yield torch.cat( @@ -921,7 +921,10 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): ) yield input - train_batches = add_memex(task.batches(split="train"), args.proba_memex) + train_batches = add_memex( + task.batches(split="train"), + args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0, + ) for input in train_batches: model.reset_inner_loss() -- 2.20.1