Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 6254807..a587e96 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -453,6 +453,7 @@ except FileExistsError:
         exit(1)
 
 loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -530,7 +531,7 @@ def get_lr(n_epoch, it):
 ######################################################################
 
 
-def add_memex_v2(batches, memex_proba, marker_token):
+def add_memex_v1(batches, memex_proba, marker_token):
     for input in batches:
         if torch.rand(1).item() < memex_proba:
             t = (
@@ -561,61 +562,101 @@ def add_memex_v2(batches, memex_proba, marker_token):
             new_input = input[n, t.clamp(min=0)]
             new_input = (1 - m) * new_input + m * (marker_token)
 
-            yield new_input
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
+
+            yield new_input, memex_mask
 
         yield input
 
 
-def add_memex_v3(batches, memex_proba, marker_token):
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
     for input in batches:
         if torch.rand(1).item() < memex_proba:
-            t = (
-                torch.arange(2 * input.size(1), device=input.device)[None, :]
-                .expand(input.size(0), -1)
-                .clone()
+            t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+                input.size(0), -1
+            )
+            t = t + torch.randint(
+                input.size(1) - t.size(1), (t.size(0), 1), device=t.device
+            )
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
             )
 
-            u = torch.rand(t.size(), device=t.device)
-            u[:, : input.size(1)] = 1.0
-            memex_v3_proba_fragment = 1 / 20
-            u = (u < memex_v3_proba_fragment).long()
-            v = u * torch.randint(input.size(1), u.size())
-            u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
-                :, : input.size(1) - 1
-            ] * input.size(1)
-            u = u.cumsum().clamp(min=0)
+            flash = input[n, t]
+            new_input = torch.cat([input, flash], dim=1)
 
-            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
-            caterpillar_length = args.nb_lines // args.caterpillar_height
-            u1 = (
-                u0
-                + torch.randint(
-                    caterpillar_length, (input.size(0), 1), device=input.device
-                )
-                + 1
-            )
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
 
-            m0 = (t < u0).long()
-            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+            yield new_input, memex_mask
 
-            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
-            m = (t < 0).long()
+        else:
+            yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            memex_len = input.size(1) // 4
+
+            t = torch.arange(input.size(1) + memex_len, device=input.device)[
+                None, :
+            ].expand(input.size(0), -1)
             n = torch.arange(input.size(0), device=input.device)[:, None].expand(
                 -1, t.size(1)
             )
 
-            new_input = input[n, t.clamp(min=0)]
-            new_input = (1 - m) * new_input + m * (marker_token)
+            # Call me the tensor-spaghetti master
 
-            yield new_input
+            trigger = torch.rand(t.size(), device=t.device)
+            trigger[:, -memex_len:] = 2.0
+            trigger[:, 0] = 2.0
+            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+            memex_mask = trigger.clone()
+            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+            memex_mask = memex_mask.cumsum(dim=1)
 
-        yield input
+            u = 1 - memex_mask
+            u[:, 0] = 0
+            u = u.cumsum(dim=1)
+            assert u.min() == 0
+            assert u.max() == input.size(1) - 1
+
+            v = (
+                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+                + torch.randint(
+                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
+                )
+            ) * memex_mask
+            assert v.min() >= 0
+            assert v.max() < input.size(1)
+            u = u * (1 - memex_mask) + v * memex_mask
+
+            new_input = input[n, u]
+            assert input.max() < vocabulary_size
+            assert new_input.max() < vocabulary_size
+            limits = trigger.clone()
+            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+            assert limits.min() == 0
+            assert limits.max() == 1
+            new_input = new_input * (1 - limits) + marker_token * limits
+            assert marker_token < vocabulary_size
+            assert new_input.max() < vocabulary_size
+
+            yield new_input, memex_mask
+
+        else:
+            yield input
 
 
 ######################################################################
 
 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
 
+assert args.batch_size % args.physical_batch_size == 0
+
 
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
@@ -829,6 +870,7 @@ log_string(f"device {device}")
 vocabulary_size = task.vocabulary_size()
 
 if args.memex_proba > 0:
+    memex_marker = vocabulary_size
     vocabulary_size += 1
 
 log_string(f"vocabulary_size {vocabulary_size}")
@@ -978,6 +1020,32 @@ it = 0
 
 n_batch = 0
 
+
+def the_dot_products(value1, value2, params):
+    g1g1, g1g2, g2g2 = 0, 0, 0
+    for p in params:
+        g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+        g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+        g1g1 += g1.pow(2).sum()[None]
+        g2g2 += g2.pow(2).sum()[None]
+        g1g2 += (g1 * g2).sum()[None]
+    return torch.cat([g1g1, g1g2, g2g2])
+
+
+def update_ave_grad(value, params, name, eps=1e-3):
+    for p in params:
+        g = torch.autograd.grad(value, p, retain_graph=True)[0]
+        ag = getattr(p, name) if hasattr(p, name) else 0
+        setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
+    s = 0
+    for p in params:
+        s += getattr(p, name).pow(2).sum()
+    return s
+
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -1000,10 +1068,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     log_string(f"memex_proba {memex_proba}")
 
-    train_batches = add_memex_v2(
+    warnings.warn("memex v3", RuntimeWarning)
+    train_batches = add_memex_v3(
         batches=task.batches(split="train"),
         memex_proba=memex_proba,
-        marker_token=vocabulary_size - 1,
+        marker_token=memex_marker,
     )
 
     def add_none(it):
@@ -1015,11 +1084,35 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     for input in add_none(train_batches):
         if input is not None:
+            if type(input) is tuple:
+                input, memex_mask = input
+                memex_mask = memex_mask.to(device)
+            else:
+                memex_mask = None
+
             model.reset_inner_loss()
             input = input.to(device)
 
             output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+
+            if memex_mask is None:
+                loss = F.cross_entropy(output.transpose(1, 2), input)
+            else:
+                loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                loss_regular = (loss * (1 - memex_mask)).mean()
+                loss_memex = (loss * memex_mask).mean()
+
+                if it < 100 or torch.rand(1) < 0.01:
+                    update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+                    update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+                    norm_regular = norm(model.parameters(), "grad_regular")
+                    norm_memex = norm(model.parameters(), "grad_memex")
+                    l_memex = (
+                        max(norm_regular, norm_memex) - norm_regular
+                    ) / norm_memex
+
+                loss = loss_regular + l_memex * loss_memex
+
             inner_loss = model.get_inner_loss()
 
             acc_train_loss += loss.item() * input.size(0)
@@ -1047,10 +1140,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             optimizer.step()
             grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
             loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            lambda_file.write(
+                f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+            )
             optimizer.zero_grad()
             nb_acc_samples = 0
-
-        n_batch += 1
+            n_batch += 1
 
     with torch.autograd.no_grad():
         model.eval()