X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=91c885b5177619ab33d6b62ec7c87825ab1d2543;hb=7b2d37d9c7ffb10f9bd81ef6356ef7083614a380;hp=4d5077ae1b46b7398aed253c14129d5f6b451879;hpb=a3c32b845b6903fd290f2b09d5c53203ff112b79;p=mygptrnn.git diff --git a/main.py b/main.py index 4d5077a..91c885b 100755 --- a/main.py +++ b/main.py @@ -87,7 +87,7 @@ parser.add_argument("--model", type=str, default=None) parser.add_argument("--attention", type=str, default=None) -parser.add_argument("--proportion_memex", type=float, default=0) +parser.add_argument("--proba_memex", type=float, default=0) parser.add_argument("--dim_model", type=int, default=None) @@ -738,7 +738,7 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() -if args.proportion_memex > 0: +if args.proba_memex > 0: vocabulary_size += 1 log_string(f"vocabulary_size {vocabulary_size}") @@ -902,22 +902,26 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - def add_memex(batches, proportion_memex): + def add_memex(batches, proba_memex): for input in batches: - if torch.rand(1).item() < proportion_memex: + if torch.rand(1).item() < proba_memex: + sep = ( + torch.full( + (input.size(0), 1), vocabulary_size - 1, device=input.device + ), + ) + yield torch.cat( [ input, - torch.full( - (input.size(0), 1), vocabulary_size - 1, device=input.device - ), + sep, input, ], dim=1, ) yield input - train_batches = add_memex(task.batches(split="train"), args.proportion_memex) + train_batches = add_memex(task.batches(split="train"), args.proba_memex) for input in train_batches: model.reset_inner_loss()