Update: add gradient accumulation (--physical_batch_size vs --batch_size), memex batch augmentation, gate-dropout and --force_cpu options, extra --key=value arguments (sup_args), and per-batch loss/lambda logging; remove --dim_rec_v.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 18c0730..a587e96 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -11,18 +11,23 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
 import ffutils
 import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
 
-######################################################################
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
 
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
@@ -44,11 +49,15 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
@@ -68,7 +77,7 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
 # legacy
 
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
 
 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
 
@@ -82,6 +91,10 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -94,9 +107,13 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
 
-parser.add_argument("--dim_rec_v", type=int, default=None)
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -108,7 +125,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--continue_training", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
@@ -130,6 +147,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -199,109 +220,119 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default
 
 ######################################################################
 
-args = parser.parse_args()
+args, sup_args = parser.parse_known_args()
 
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
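+# Any extra --key=value options not declared above end up in sup_args
+# (e.g. --calibrate=1, which is checked further down).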
 
 if args.result_dir is None:
     args.result_dir = f"results_{args.task}_{args.model}"
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "byheart": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "expr": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "grid": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "qmlp": {
         "model": "37M",
-        "batch_size": 10,
+        "physical_batch_size": 10,
         "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
     "guessop": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 1000000,
         "nb_test_samples": 10000,
     },
     "learnop": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "maze": {
         "model": "37M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 100000,
         "nb_test_samples": 10000,
     },
     "picoclvr": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "rpl": {
         "model": "352M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "snake": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "stack": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
     "twotargets": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "memory": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 25000,
         "nb_test_samples": 10000,
     },
     "mixing": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "mnist": {
         "model": "37M",
-        "batch_size": 10,
+        "physical_batch_size": 5,
         "nb_train_samples": 60000,
         "nb_test_samples": 10000,
     },
@@ -321,7 +352,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 32,
         "nb_heads": 2,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "17K-C": {
@@ -332,7 +362,6 @@ default_model_args = {
         "nb_heads": 2,
         "nb_lines": 16,
         "caterpillar_height": 4,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "4M": {
@@ -341,7 +370,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 1024,
         "nb_heads": 4,
-        "dim_rec_v": 64,
         "nb_blocks": 6,
     },
     "4M-C": {
@@ -352,7 +380,6 @@ default_model_args = {
         "nb_heads": 4,
         "nb_lines": 32,
         "caterpillar_height": 4,
-        "dim_rec_v": 64,  # dim_model / nb_heads
         "nb_blocks": 6,
     },
     "37M": {
@@ -361,7 +388,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "37M-C": {
@@ -372,7 +398,6 @@ default_model_args = {
         "nb_heads": 8,
         "nb_lines": 256,
         "caterpillar_height": 32,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "122M": {
@@ -381,7 +406,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "122M-C": {
@@ -391,7 +415,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "352M": {
@@ -400,7 +423,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
     "352M-C": {
@@ -410,7 +432,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
 }
@@ -427,10 +448,13 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.overwrite_results:
+    if not args.continue_training:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
+
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
 if args.seed >= 0:
@@ -467,6 +491,9 @@ log_string(f"argv {' '.join(sys.argv)}")
 for n in vars(args):
     log_string(f"args.{n} {getattr(args, n)}")
 
+for k, v in sup_args.items():
+    log_string(f'sup_args["{k}"] "{v}"')
+
 
 ######################################################################
 
@@ -504,6 +531,133 @@ def get_lr(n_epoch, it):
 ######################################################################
 
 
+def add_memex_v1(batches, memex_proba, marker_token):
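+    # With probability memex_proba, first yield a sequence of length 2*T+1 made
+    # of a random prefix of the original, marker tokens, then the full original
+    # sequence again at a random offset (marker-padded), with a mask over its
+    # second half; the original batch is always yielded afterward.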
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            t = (
+                torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+                .expand(input.size(0), -1)
+                .clone()
+            )
+
+            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+            caterpillar_length = args.nb_lines // args.caterpillar_height
+            u1 = (
+                u0
+                + torch.randint(
+                    caterpillar_length, (input.size(0), 1), device=input.device
+                )
+                + 1
+            )
+
+            m0 = (t < u0).long()
+            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+            m = (t < 0).long()
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            new_input = input[n, t.clamp(min=0)]
+            new_input = (1 - m) * new_input + m * (marker_token)
+
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
+
+            yield new_input, memex_mask
+
+        yield input
+
+
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
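+    # With probability memex_proba, yield the sequence with a copy of a random
+    # contiguous quarter of it appended at the end, together with a mask over
+    # the appended part; otherwise yield the batch unchanged.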
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+                input.size(0), -1
+            )
+            t = t + torch.randint(
+                input.size(1) - t.size(1), (t.size(0), 1), device=t.device
+            )
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            flash = input[n, t]
+            new_input = torch.cat([input, flash], dim=1)
+
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
+
+            yield new_input, memex_mask
+
+        else:
+            yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
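+    # With probability memex_proba, splice a copy of a random contiguous chunk
+    # of memex_len = T//4 tokens into the sequence at a random position, with
+    # the first and last positions of the window replaced by the marker token,
+    # and yield it with a mask over the spliced-in window; otherwise yield the
+    # batch unchanged.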
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            memex_len = input.size(1) // 4
+
+            t = torch.arange(input.size(1) + memex_len, device=input.device)[
+                None, :
+            ].expand(input.size(0), -1)
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            # Call me the tensor-spaghetti master
+
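+            # trigger: one random position per row where the window starts
+            # (never position 0 nor one of the last memex_len positions);
+            # memex_mask: 1 over the memex_len positions starting there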
+            trigger = torch.rand(t.size(), device=t.device)
+            trigger[:, -memex_len:] = 2.0
+            trigger[:, 0] = 2.0
+            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+            memex_mask = trigger.clone()
+            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+            memex_mask = memex_mask.cumsum(dim=1)
+
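+            # u: outside the window, indices walking the original sequence in
+            # order (constant inside the window, where v overrides it below)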
+            u = 1 - memex_mask
+            u[:, 0] = 0
+            u = u.cumsum(dim=1)
+            assert u.min() == 0
+            assert u.max() == input.size(1) - 1
+
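+            # v: inside the window, indices of a random contiguous chunk of
+            # memex_len original tokens; zero elsewhere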
+            v = (
+                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+                + torch.randint(
+                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
+                )
+            ) * memex_mask
+            assert v.min() >= 0
+            assert v.max() < input.size(1)
+            u = u * (1 - memex_mask) + v * memex_mask
+
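+            # new_input: the original sequence with the chunk spliced in at the
+            # window, whose first and last positions become the marker token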
+            new_input = input[n, u]
+            assert input.max() < vocabulary_size
+            assert new_input.max() < vocabulary_size
+            limits = trigger.clone()
+            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+            assert limits.min() == 0
+            assert limits.max() == 1
+            new_input = new_input * (1 - limits) + marker_token * limits
+            assert marker_token < vocabulary_size
+            assert new_input.max() < vocabulary_size
+
+            yield new_input, memex_mask
+
+        else:
+            yield input
+
+
+######################################################################
+
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
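+# Gradients are accumulated over batch_size // physical_batch_size physical
+# batches before each optimizer step, hence the divisibility requirement.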
+assert args.batch_size % args.physical_batch_size == 0
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -529,7 +683,7 @@ if args.task == "byheart":
         problem=problems.ProblemByHeart(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -540,7 +694,7 @@ elif args.task == "learnop":
         problem=problems.ProblemLearnOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -551,7 +705,7 @@ elif args.task == "guessop":
         problem=problems.ProblemGuessOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -562,7 +716,7 @@ elif args.task == "twotargets":
         problem=problems.ProblemTwoTargets(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -572,7 +726,7 @@ elif args.task == "memory":
         problem=problems.ProblemMemory(len_total=args.memory_len_total),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -584,7 +738,7 @@ elif args.task == "mixing":
         ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -594,7 +748,7 @@ elif args.task == "addition":
         problem=problems.ProblemAddition(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -603,7 +757,7 @@ elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.picoclvr_height,
         width=args.picoclvr_width,
         nb_colors=args.picoclvr_nb_colors,
@@ -617,7 +771,7 @@ elif args.task == "mnist":
     task = tasks.MNIST(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -625,7 +779,7 @@ elif args.task == "maze":
     task = tasks.Maze(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.maze_height,
         width=args.maze_width,
         nb_walls=args.maze_nb_walls,
@@ -636,7 +790,7 @@ elif args.task == "snake":
     task = tasks.Snake(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.snake_height,
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
@@ -649,7 +803,7 @@ elif args.task == "stack":
     task = tasks.Stack(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
@@ -666,7 +820,7 @@ elif args.task == "expr":
         sequence_length=args.expr_sequence_length,
         operand_max=args.expr_operand_max,
         result_max=args.expr_result_max,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -674,7 +828,7 @@ elif args.task == "rpl":
     task = tasks.RPL(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         nb_starting_values=args.rpl_nb_starting_values,
         max_input=args.rpl_max_input,
         prog_len=args.rpl_prog_len,
@@ -688,8 +842,10 @@ elif args.task == "grid":
     task = tasks.Grid(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         size=args.grid_size,
+        nb_shapes=args.grid_nb_shapes,
+        nb_colors=args.grid_nb_colors,
         logger=log_string,
         device=device_data,
     )
@@ -698,7 +854,7 @@ elif args.task == "qmlp":
     task = tasks.QMLP(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         result_dir=args.result_dir,
         logger=log_string,
         device=device_data,
@@ -713,6 +869,10 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.memex_proba > 0:
+    memex_marker = vocabulary_size
+    vocabulary_size += 1
+else:
+    memex_marker = None  # never used when the memex augmentation is disabled
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -725,11 +885,12 @@ model = mygpt.MyGPT(
     nb_heads=args.nb_heads,
     nb_lines=args.nb_lines,
     caterpillar_height=args.caterpillar_height,
-    dim_rec_v=args.dim_rec_v,
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
     attention_layer=args.attention,
+    logger=log_string,
+    args=args,
 )
 
 model.to(device)
@@ -823,6 +984,25 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
+if "calibrate" in sup_args:
+    for input in task.batches(split="train", desc="calibrate"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
+
+    for n, m in model.named_modules():
+        for a in dir(m):
+            x = getattr(m, a)
+            if isinstance(x, mygpt.Calibrator):
+                print(f"####### ${n} | ${a} ########################")
+                mean, std = x.moments()
+                print("mean\n", mean, "\n")
+                print("std\n", std, "\n")
+                print(f"############################################\n\n")
+
+    exit(0)
+
+##############################
+
 nb_samples_seen = 0
 
 if nb_epochs_finished >= nb_epochs:
@@ -834,10 +1014,38 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
+n_batch = 0
+
+
+def the_dot_products(value1, value2, params):
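+    # Squared norms and dot product of the gradients of value1 and value2
+    # w.r.t. params; graphs are retained so the values can be differentiated
+    # again afterward.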
+    g1g1, g1g2, g2g2 = 0, 0, 0
+    for p in params:
+        g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+        g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+        g1g1 += g1.pow(2).sum()[None]
+        g2g2 += g2.pow(2).sum()[None]
+        g1g2 += (g1 * g2).sum()[None]
+    return torch.cat([g1g1, g1g2, g2g2])
+
+
+def update_ave_grad(value, params, name, eps=1e-3):
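+    # Keep, as attribute `name` of each parameter, an exponential moving
+    # average (rate eps) of the gradient of value w.r.t. that parameter.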
+    for p in params:
+        g = torch.autograd.grad(value, p, retain_graph=True)[0]
+        ag = getattr(p, name) if hasattr(p, name) else 0
+        setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
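+    # Squared L2 norm of the averaged gradients stored under attribute `name`.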
+    s = 0
+    for p in params:
+        s += getattr(p, name).pow(2).sum()
+    return s
+
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -852,32 +1060,92 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
-        model.reset_inner_loss()
-        input = input.to(device)
+    memex_proba = (
+        args.memex_proba
+        if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+        else 0.0
+    )
 
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        inner_loss = model.get_inner_loss()
+    log_string(f"memex_proba {memex_proba}")
 
-        acc_train_loss += loss.item() * input.size(0)
-        acc_train_inner_loss += inner_loss.item() * input.size(0)
+    warnings.warn("memex v3", RuntimeWarning)
+    train_batches = add_memex_v3(
+        batches=task.batches(split="train"),
+        memex_proba=memex_proba,
+        marker_token=memex_marker,
+    )
 
-        nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
+    def add_none(it):
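+        # Append a trailing None so that the last, possibly incomplete,
+        # accumulation group still triggers an optimizer step below.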
+        for x in it:
+            yield x
+        yield None
+
+    nb_acc_samples = 0
+
+    # Start from neutral values so that the lambda logging below works even
+    # before any memex batch has been processed.
+    l_memex, norm_regular, norm_memex = 0.0, 0.0, 0.0
+
+    for input in add_none(train_batches):
+        if input is not None:
+            if type(input) is tuple:
+                input, memex_mask = input
+                memex_mask = memex_mask.to(device)
+            else:
+                memex_mask = None
+
+            model.reset_inner_loss()
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+            if memex_mask is None:
+                loss = F.cross_entropy(output.transpose(1, 2), input)
+            else:
+                loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                loss_regular = (loss * (1 - memex_mask)).mean()
+                loss_memex = (loss * memex_mask).mean()
 
-        it += 1
-        lr = get_lr(n_epoch, it)
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = lr
+                if it < 100 or torch.rand(1) < 0.01:
+                    update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+                    update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+                    norm_regular = norm(model.parameters(), "grad_regular")
+                    norm_memex = norm(model.parameters(), "grad_memex")
+                    l_memex = (
+                        max(norm_regular, norm_memex) - norm_regular
+                    ) / norm_memex
 
-        # log_string(f"learning_rate {lr}")
+                loss = loss_regular + l_memex * loss_memex
 
-        optimizer.zero_grad()
-        total_loss.backward()
-        optimizer.step()
+            inner_loss = model.get_inner_loss()
+
+            acc_train_loss += loss.item() * input.size(0)
+            acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+            nb_samples_seen += input.size(0)
+
+            total_loss = loss + (
+                args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+            )
+
+            it += 1
+            lr = get_lr(n_epoch, it)
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = lr
+
+                # log_string(f"learning_rate {lr}")
+
+            total_loss.backward()
+            nb_acc_samples += input.size(0)
+
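+        # Take an optimizer step once args.batch_size samples have been
+        # accumulated, or at the very end of the epoch, and log the loss,
+        # gradient norm, and memex statistics.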
+        if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+            assert nb_acc_samples <= args.batch_size
+            optimizer.step()
+            grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+            loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            lambda_file.write(
+                f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+            )
+            optimizer.zero_grad()
+            nb_acc_samples = 0
+            n_batch += 1
 
     with torch.autograd.no_grad():
         model.eval()
@@ -912,10 +1180,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
         time_pred_result = time_current_result
 
     checkpoint = {