Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 17 Feb 2024 08:53:23 +0000 (09:53 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 17 Feb 2024 08:53:23 +0000 (09:53 +0100)
fridge
main.py

diff --git a/fridge b/fridge
index 143092c..a4d860b 100644 (file)
--- a/fridge
+++ b/fridge
@@ -335,3 +335,73 @@ class Calibrator:
         ) % k_star.size(0)
         k_star = k_star[l_barrel, t_barrel]
 
+
+######################################################################
+
+2024 Feb 15 23:10:50 (from main.py)
+
+
+def add_memex_v4(batches, memex_proba, marker_token):
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            t = (
+                torch.arange(2 * input.size(1), device=input.device)[None, :]
+                .expand(input.size(0), -1)
+                .clone()
+            )
+
+            u = torch.rand(t.size(), device=t.device)
+            u[:, : input.size(1)] = 1.0
+            memex_v3_proba_fragment = 1 / 20
+            u = (u < memex_v3_proba_fragment).long()
+            v = u * torch.randint(input.size(1), u.size())
+            u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
+                :, : input.size(1) - 1
+            ] * input.size(1)
+            u = u.cumsum().clamp(min=0)
+
+            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+            caterpillar_length = args.nb_lines // args.caterpillar_height
+            u1 = (
+                u0
+                + torch.randint(
+                    caterpillar_length, (input.size(0), 1), device=input.device
+                )
+                + 1
+            )
+
+            m0 = (t < u0).long()
+            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+            m = (t < 0).long()
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            new_input = input[n, t.clamp(min=0)]
+            new_input = (1 - m) * new_input + m * (marker_token)
+
+            yield new_input
+
+        yield input
+
+
+
+######################################################################
+
+2024 Feb 16 17:07:48 (from main.py)
+
+                # ||gn + lambda * gm|| = max(||gn||,||gm||)
+                # ||gn||^2 + lambda<gn,gm> + lambda^2||gm||^2 = max(||gn||^2,||gm||^2)
+                # A = ||gm||^2 B = <gn,gm> C = ||gn||^2 - max(||gn||^2, ||gm||^2)
+
+######################################################################
+
+2024 Feb 16 17:07:51 (from main.py)
+
+                # A,B,C = gmgm, gngm, gngn - max(gngn,gmgm)
+                # Delta = B*B - 4*A*C
+                # if(delta >= 0):
+                    # l = ( -B - sqrt(Delta))/(2*A)
+                # ||gn||+l*rho*||gm|| = max(||gn||,rho*||gm||)
diff --git a/main.py b/main.py
index 6254807..2a90fd1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -453,6 +453,7 @@ except FileExistsError:
         exit(1)
 
 loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -530,7 +531,7 @@ def get_lr(n_epoch, it):
 ######################################################################
 
 
-def add_memex_v2(batches, memex_proba, marker_token):
+def add_memex_v1(batches, memex_proba, marker_token):
     for input in batches:
         if torch.rand(1).item() < memex_proba:
             t = (
@@ -561,61 +562,45 @@ def add_memex_v2(batches, memex_proba, marker_token):
             new_input = input[n, t.clamp(min=0)]
             new_input = (1 - m) * new_input + m * (marker_token)
 
-            yield new_input
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
+
+            yield new_input, memex_mask
 
         yield input
 
 
-def add_memex_v3(batches, memex_proba, marker_token):
+def add_memex_v2(batches, memex_proba):
     for input in batches:
         if torch.rand(1).item() < memex_proba:
-            t = (
-                torch.arange(2 * input.size(1), device=input.device)[None, :]
-                .expand(input.size(0), -1)
-                .clone()
+            t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+                input.size(0), -1
             )
-
-            u = torch.rand(t.size(), device=t.device)
-            u[:, : input.size(1)] = 1.0
-            memex_v3_proba_fragment = 1 / 20
-            u = (u < memex_v3_proba_fragment).long()
-            v = u * torch.randint(input.size(1), u.size())
-            u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
-                :, : input.size(1) - 1
-            ] * input.size(1)
-            u = u.cumsum().clamp(min=0)
-
-            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
-            caterpillar_length = args.nb_lines // args.caterpillar_height
-            u1 = (
-                u0
-                + torch.randint(
-                    caterpillar_length, (input.size(0), 1), device=input.device
-                )
-                + 1
+            t = t + torch.randint(
+                input.size(1) - t.size(1), (t.size(0), 1), device=t.device
             )
-
-            m0 = (t < u0).long()
-            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
-
-            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
-            m = (t < 0).long()
             n = torch.arange(input.size(0), device=input.device)[:, None].expand(
                 -1, t.size(1)
             )
 
-            new_input = input[n, t.clamp(min=0)]
-            new_input = (1 - m) * new_input + m * (marker_token)
+            flash = input[n, t]
+            new_input = torch.cat([input, flash], dim=1)
 
-            yield new_input
+            memex_mask = new_input.new_zeros(new_input.size())
+            memex_mask[:, input.size(1) :] = 1.0
 
-        yield input
+            yield new_input, memex_mask
+
+        else:
+            yield input
 
 
 ######################################################################
 
 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
 
+assert args.batch_size % args.physical_batch_size == 0
+
 
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
@@ -978,6 +963,21 @@ it = 0
 
 n_batch = 0
 
+
+def the_dot_products(value1, value2, params):
+    g1g1, g1g2, g2g2 = 0, 0, 0
+    for p in params:
+        g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+        g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+        g1g1 += g1.pow(2).sum()[None]
+        g2g2 += g2.pow(2).sum()[None]
+        g1g2 += (g1 * g2).sum()[None]
+    return torch.cat([g1g1, g1g2, g2g2])
+
+
+movave_dot_products = 0
+
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -1003,7 +1003,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
     train_batches = add_memex_v2(
         batches=task.batches(split="train"),
         memex_proba=memex_proba,
-        marker_token=vocabulary_size - 1,
     )
 
     def add_none(it):
@@ -1015,11 +1014,37 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     for input in add_none(train_batches):
         if input is not None:
+            if type(input) is tuple:
+                input, memex_mask = input
+                memex_mask = memex_mask.to(device)
+            else:
+                memex_mask = None
+
             model.reset_inner_loss()
             input = input.to(device)
 
             output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+
+            if memex_mask is None:
+                loss = F.cross_entropy(output.transpose(1, 2), input)
+            else:
+                loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                loss_regular = (loss * (1 - memex_mask)).mean()
+                loss_memex = (loss * memex_mask).mean()
+
+                if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01:
+                    dot_products = the_dot_products(
+                        loss_regular, loss_memex, model.parameters()
+                    )
+                    eps = 1e-3
+                    movave_dot_products = (
+                        1 - eps
+                    ) * movave_dot_products + eps * dot_products
+
+                grgr, grgm, gmgm = movave_dot_products
+                l = (max(grgr, gmgm) - grgr) / gmgm
+                loss = loss_regular + l * loss_memex
+
             inner_loss = model.get_inner_loss()
 
             acc_train_loss += loss.item() * input.size(0)
@@ -1047,10 +1072,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             optimizer.step()
             grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
             loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            grgr, grgm, gmgm = movave_dot_products
+            l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm)
+            lambda_file.write(f"{n_epoch} {n_batch} {l} {grgr} {gmgm}\n")
             optimizer.zero_grad()
             nb_acc_samples = 0
-
-        n_batch += 1
+            n_batch += 1
 
     with torch.autograd.no_grad():
         model.eval()