+ def add_memex(batches, proportion_memex):
+ for input in batches:
+ if torch.rand(1).item() < proportion_memex:
+ yield torch.cat(
+ [
+ input,
+ torch.full(
+ (input.size(0), 1), vocabulary_size - 1, device=input.device
+ ),
+ input,
+ ],
+ dim=1,
+ )
+ yield input
+
+ train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+
+ for input in train_batches: