Update.
author    François Fleuret <francois@fleuret.org>
          Thu, 18 Jan 2024 06:51:11 +0000 (07:51 +0100)
committer François Fleuret <francois@fleuret.org>
          Thu, 18 Jan 2024 06:51:11 +0000 (07:51 +0100)
fridge
grid.py
main.py
mygpt.py
tasks.py

diff --git a/fridge b/fridge
index f87c1df..d09e92d 100644
--- a/fridge
+++ b/fridge
@@ -204,3 +204,91 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
                 + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )
+
+######################################################################
+
+2024 Jan 18 07:39:29 (from mygpt.py)
+
+class Calibrator:
+    def __init__(self, w=None, b=None):
+        self.w = w
+        self.b = b
+        self.s, self.s_sq, self.n = 0, 0, 0
+        self.mean, self.std = 0, 0
+
+    def update(self, X):
+        X = X.detach()
+        self.s += X.sum(dim=0)
+        self.s_sq += X.pow(2).sum(dim=0)
+        self.n += X.size(0)
+
+    def moments(self):
+        mean = self.s / self.n
+        std = (self.s_sq / self.n - mean * mean).sqrt()
+        return mean, std
+
+    def normalize(self):
+        mean, std = self.moments()
+        if self.b is not None:
+            self.b.sub_(mean)
+        if self.w is not None:
+            self.w.div_(std)
+        result = mean - self.mean, std - self.std
+        self.mean, self.std = mean, std
+        self.s, self.s_sq, self.n = 0, 0, 0
+        return result
+
+
+
+######################################################################
+
+2024 Jan 18 07:39:34 (from mygpt.py)
+
+        # self.calibrator_G = Calibrator()
+        # self.calibrator_rec_V = Calibrator()
+        # self.calibrator_rec_K = Calibrator()
+
+
+######################################################################
+
+2024 Jan 18 07:39:37 (from mygpt.py)
+
+        # self.calibrator_G.update(G.reshape(-1, G.size(-1)))
+
+
+######################################################################
+
+2024 Jan 18 07:39:42 (from mygpt.py)
+
+        # self.calibrator_rec_V.update(
+        # next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
+        # )
+        # self.calibrator_rec_K.update(
+        # next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
+        # )
+
+
+######################################################################
+
+2024 Jan 18 07:47:12 (from mygpt.py)
+
+        ######################################################################
+        # Roll the gating indexes
+
+        # warnings.warn("rotating barrel", RuntimeWarning)
+
+        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+
+
+######################################################################
+
+2024 Jan 18 07:47:25 (from mygpt.py)
+
+        # warnings.warn("harmonic recurrence", RuntimeWarning)
+        # har = torch.arange(t0, t1, device = G.device).float() + 1
+        # A = har / (har + 1)
+        # G = G / har
+
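
For reference, the Calibrator moved to the fridge above accumulates per-feature first and second moments across batches; moments() turns them into a mean/std pair, and normalize() folds that pair into an optional weight/bias, returns the drift since the previous call, and resets the accumulators. A minimal sketch of how it can be driven; the sizes and tensors are illustrative, not from the commit:

    import torch

    w, b = torch.ones(16), torch.zeros(16)  # hypothetical per-feature parameters
    calib = Calibrator(w=w, b=b)

    # Accumulate statistics over a few (batch, feature) activation tensors
    for _ in range(10):
        X = torch.randn(64, 16) * 3.0 + 1.5
        calib.update(X)

    mean, std = calib.moments()  # per-feature running mean / std

    # Subtract mean from b, divide w by std, return the drift since the
    # previous call, and reset the accumulators
    d_mean, d_std = calib.normalize()
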
diff --git a/grid.py b/grid.py
index 268f4ee..f9f1557 100755
--- a/grid.py
+++ b/grid.py
@@ -9,10 +9,6 @@ import math
 import torch, torchvision
 import torch.nn.functional as F
 
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
 ######################################################################
 
 
@@ -23,20 +19,160 @@ class GridFactory:
         max_nb_items=4,
         max_nb_transformations=3,
         nb_questions=4,
+        nb_shapes=6,
+        nb_colors=6,
     ):
         assert size % 2 == 0
         self.size = size
         self.max_nb_items = max_nb_items
         self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
+        self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)]
+        self.name_colors = [
+            "red",
+            "yellow",
+            "blue",
+            "green",
+            "white",
+            "black",
+            "maroon",
+            "dark_red",
+            "brown",
+            "firebrick",
+            "crimson",
+            "tomato",
+            "coral",
+            "indian_red",
+            "light_coral",
+            "dark_salmon",
+            "salmon",
+            "light_salmon",
+            "orange_red",
+            "dark_orange",
+            "orange",
+            "gold",
+            "dark_golden_rod",
+            "golden_rod",
+            "pale_golden_rod",
+            "dark_khaki",
+            "khaki",
+            "olive",
+            "yellow_green",
+            "dark_olive_green",
+            "olive_drab",
+            "lawn_green",
+            "chartreuse",
+            "green_yellow",
+            "dark_green",
+            "forest_green",
+            "lime",
+            "lime_green",
+            "light_green",
+            "pale_green",
+            "dark_sea_green",
+            "medium_spring_green",
+            "spring_green",
+            "sea_green",
+            "medium_aqua_marine",
+            "medium_sea_green",
+            "light_sea_green",
+            "dark_slate_gray",
+            "teal",
+            "dark_cyan",
+            "aqua",
+            "cyan",
+            "light_cyan",
+            "dark_turquoise",
+            "turquoise",
+            "medium_turquoise",
+            "pale_turquoise",
+            "aqua_marine",
+            "powder_blue",
+            "cadet_blue",
+            "steel_blue",
+            "corn_flower_blue",
+            "deep_sky_blue",
+            "dodger_blue",
+            "light_blue",
+            "sky_blue",
+            "light_sky_blue",
+            "midnight_blue",
+            "navy",
+            "dark_blue",
+            "medium_blue",
+            "royal_blue",
+            "blue_violet",
+            "indigo",
+            "dark_slate_blue",
+            "slate_blue",
+            "medium_slate_blue",
+            "medium_purple",
+            "dark_magenta",
+            "dark_violet",
+            "dark_orchid",
+            "medium_orchid",
+            "purple",
+            "thistle",
+            "plum",
+            "violet",
+            "magenta",
+            "orchid",
+            "medium_violet_red",
+            "pale_violet_red",
+            "deep_pink",
+            "hot_pink",
+            "light_pink",
+            "pink",
+            "antique_white",
+            "beige",
+            "bisque",
+            "blanched_almond",
+            "wheat",
+            "corn_silk",
+            "lemon_chiffon",
+            "light_golden_rod_yellow",
+            "light_yellow",
+            "saddle_brown",
+            "sienna",
+            "chocolate",
+            "peru",
+            "sandy_brown",
+            "burly_wood",
+            "tan",
+            "rosy_brown",
+            "moccasin",
+            "navajo_white",
+            "peach_puff",
+            "misty_rose",
+            "lavender_blush",
+            "linen",
+            "old_lace",
+            "papaya_whip",
+            "sea_shell",
+            "mint_cream",
+            "slate_gray",
+            "light_slate_gray",
+            "light_steel_blue",
+            "lavender",
+            "floral_white",
+            "alice_blue",
+            "ghost_white",
+            "honeydew",
+            "ivory",
+            "azure",
+            "snow",
+            "silver",
+            "gainsboro",
+            "white_smoke",
+        ][:nb_colors]
 
     def generate_scene(self):
         nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
         col = torch.full((self.size * self.size,), -1)
         shp = torch.full((self.size * self.size,), -1)
-        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
-        col[:nb_items] = a % len(name_colors)
-        shp[:nb_items] = a // len(name_colors)
+        a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+        col[:nb_items] = a % len(self.name_colors)
+        shp[:nb_items] = a // len(self.name_colors)
         i = torch.randperm(self.size * self.size)
         col = col[i]
         shp = shp[i]
@@ -76,12 +212,15 @@ class GridFactory:
         # for i in range(self.size):
         # for j in range(self.size):
         # if col[i,j] >= 0:
-        # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+        # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
 
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+                    print(
+                        f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+                        end="",
+                    )
                 elif j == 0:
                     print(" +", end="")
                 else:
@@ -103,7 +242,7 @@ class GridFactory:
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+                    n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
                     properties += [f"a {n} at {i} {j}"]
 
         return properties
@@ -116,7 +255,9 @@ class GridFactory:
         for i1 in range(self.size):
             for j1 in range(self.size):
                 if col[i1, j1] >= 0:
-                    n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+                    n1 = (
+                        f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+                    )
                     properties += [f"there is a {n1}"]
                     if i1 < self.size // 2:
                         properties += [f"a {n1} is in the top half"]
@@ -129,7 +270,7 @@ class GridFactory:
                     for i2 in range(self.size):
                         for j2 in range(self.size):
                             if col[i2, j2] >= 0:
-                                n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+                                n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
                                 if i1 > i2:
                                     properties += [f"a {n1} is below a {n2}"]
                                 if i1 < i2:
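
Note that generate_scene above draws item identities without replacement: one randperm over the nb_colors * nb_shapes combined indices guarantees that no two items share both color and shape, and modulo / integer division decode the pair. A standalone sketch with illustrative sizes:

    import torch

    nb_colors, nb_shapes, nb_items = 6, 6, 4

    a = torch.randperm(nb_colors * nb_shapes)[:nb_items]  # distinct combined ids
    col = a % nb_colors   # color index of each item
    shp = a // nb_colors  # shape index of each item
    # the (col[k], shp[k]) pairs are distinct because the ids a[k] are
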
diff --git a/main.py b/main.py
index 04e5652..79841f3 100755
--- a/main.py
+++ b/main.py
@@ -133,6 +133,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -701,6 +705,8 @@ elif args.task == "grid":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         size=args.grid_size,
+        nb_shapes=args.grid_nb_shapes,
+        nb_colors=args.grid_nb_colors,
         logger=log_string,
         device=device_data,
     )
@@ -835,21 +841,22 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
-for input in task.batches(split="train", desc="calibrate"):
-    input = input.to(device)
-    output = model(mygpt.BracketedSequence(input)).x
-
-for n, m in model.named_modules():
-    for a in dir(m):
-        x = getattr(m, a)
-        if isinstance(x, mygpt.Calibrator):
-            print(f"####### ${n} | ${a} ########################")
-            mean, std = x.moments()
-            print("mean\n", mean, "\n")
-            print("std\n", std, "\n")
-            print(f"############################################\n\n")
-
-exit(0)
+if "calibrate" in sup_args:
+    for input in task.batches(split="train", desc="calibrate"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
+
+    for n, m in model.named_modules():
+        for a in dir(m):
+            x = getattr(m, a)
+            if isinstance(x, mygpt.Calibrator):
+                print(f"####### ${n} | ${a} ########################")
+                mean, std = x.moments()
+                print("mean\n", mean, "\n")
+                print("std\n", std, "\n")
+                print(f"############################################\n\n")
+
+    exit(0)
 
 ##############################
 
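
The calibration pass above is now opt-in, running only when "calibrate" appears in sup_args (whose construction is outside this hunk), and it discovers Calibrator instances by scanning every module attribute, since they are plain attributes rather than registered submodules. A self-contained sketch of that discovery pattern, with stand-in names:

    import torch.nn as nn

    class Calib:  # stand-in for mygpt.Calibrator
        def moments(self):
            return 0.0, 1.0

    class Layer(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(4, 4)
            self.calibrator_G = Calib()  # plain attribute, not a submodule

    model = nn.Sequential(Layer(), Layer())

    for n, m in model.named_modules():
        for a in dir(m):
            x = getattr(m, a)
            if isinstance(x, Calib):
                print(n, a, x.moments())
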
diff --git a/mygpt.py b/mygpt.py
index aded796..a27b99e 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -126,7 +126,6 @@ class AddPositionalEncoding(nn.Module):
 
 import pscan
 
-
 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
 
 
@@ -147,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2):
     return Y
 
 
+def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
+    with torch.no_grad():
+        s_A, s_X = 0, 0
+        for t in range(X.size(dim) - 1, 0, -1):
+            delta = (grad_Y[t] - s_A) / A[t].grad
+            s_A += A[t].grad * delta
+            A[t].grad = delta
+            delta = (grad_Y[t] - s_X) / X[t].grad
+            s_X += X[t].grad * delta
+            X[t].grad = delta
+
+
 def pscan_shape(A, X, Y_init):
     s = X.size()
     A = A.reshape(-1, s[-2])
@@ -464,36 +475,6 @@ def moving_window(x, dim, win_dim, win_size):
 ##############################
 
 
-class Calibrator:
-    def __init__(self, w=None, b=None):
-        self.w = w
-        self.b = b
-        self.s, self.s_sq, self.n = 0, 0, 0
-        self.mean, self.std = 0, 0
-
-    def update(self, X):
-        X = X.detach()
-        self.s += X.sum(dim=0)
-        self.s_sq += X.pow(2).sum(dim=0)
-        self.n += X.size(0)
-
-    def moments(self):
-        mean = self.s / self.n
-        std = (self.s_sq / self.n - mean * mean).sqrt()
-        return mean, std
-
-    def normalize(self):
-        mean, std = self.moments()
-        if self.b is not None:
-            self.b.sub_(mean)
-        if self.w is not None:
-            self.w.div_(std)
-        result = mean - self.mean, std - self.std
-        self.mean, self.std = mean, std
-        self.s, self.s_sq, self.n = 0, 0, 0
-        return result
-
-
 class Caterpillar(nn.Module):
     def __init__(
         self,
@@ -561,10 +542,6 @@ class Caterpillar(nn.Module):
             dim_v,
         )
 
-        self.calibrator_G = Calibrator()
-        self.calibrator_rec_V = Calibrator()
-        self.calibrator_rec_K = Calibrator()
-
     def reset_inner_loss(self):
         self.acc_attention = 0
         self.acc_nb = 0
@@ -620,8 +597,6 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        self.calibrator_G.update(G.reshape(-1, G.size(-1)))
-
         # warnings.warn("softmax gating", RuntimeWarning)
 
         # G = (
@@ -646,64 +621,47 @@ class Caterpillar(nn.Module):
 
             G = alpha * (1 - kill)
 
-        ######################################################################
-        # Clip the gating to avoid values greater than 1 when several
-        # heads hit the same row
+        def recurrence(G, V, K):
+            # Clip the gating to avoid values greater than 1 when several
+            # heads hit the same row
 
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+            G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        ######################################################################
-        # Roll the gating indexes
-
-        # warnings.warn("rotating barrel", RuntimeWarning)
+            # We prepare the arguments for the parallel scan
 
-        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
-        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
-        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
+            A = 1 - G.sum(1)
 
-        # We prepare the arguments for the parallel scan
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
-        A = 1 - G.sum(1)
+            # We start from cached values, which matters in inference
 
-        # warnings.warn("harmonic recurrence", RuntimeWarning)
-        # har = torch.arange(t0, t1, device = G.device).float() + 1
-        # A = har / (har + 1)
-        # G = G / har
+            init_rec_V = self.rec_V[:, :, t0 - L : t0]
+            init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
-        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
+            # Associative scan
 
-        # We start from cached values, which matters in inference
+            # Here there is a trick: since the stack at position t is
+            # computed by updating the one at position t-L, the parallel
+            # scan operates with a period of L. To do so we split the
+            # sequence index into two axes, the second of size L, and
+            # run the parallel scan using the first as the sequence index.
 
-        init_rec_V = self.rec_V[:, :, t0 - L : t0]
-        init_rec_K = self.rec_K[:, :, t0 - L : t0]
-
-        #################################################################
-        # Associative scan
+            A = A.unflatten(2, (-1, L))
+            gated_V = gated_V.unflatten(2, (-1, L))
+            gated_K = gated_K.unflatten(2, (-1, L))
 
-        # Here there is a trick: Since the stack at position t is
-        # computed by updating that at position t-L, the parallel
-        # scan operates with a period of L. To do so we split the
-        # sequence indexing in two axes, the second of size L, and
-        # run the parallel scan using the first as the sequence index.
+            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
 
-        A = A.unflatten(2, (-1, L))
-        gated_V = gated_V.unflatten(2, (-1, L))
-        gated_K = gated_K.unflatten(2, (-1, L))
+            next_V = next_V.flatten(2, 3)
+            next_K = next_K.flatten(2, 3)
 
-        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
-        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+            return next_V, next_K
 
-        next_V = next_V.flatten(2, 3)
-        next_K = next_K.flatten(2, 3)
+        #################################################################
 
-        self.calibrator_rec_V.update(
-            next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
-        )
-        self.calibrator_rec_K.update(
-            next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
-        )
+        next_V, next_K = recurrence(G, V, K)
 
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K
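
The recurrence() helper factored out above preserves the period-L trick described in its comment: the state at position t updates the state at position t-L, so after unflattening time into (T/L, L) each of the L columns becomes an ordinary first-order recurrence along the first axis, which the parallel scan then processes. A toy check of that reshaping against a naive loop; the shapes, and the scalar linear recurrence standing in for pscan_dim, are illustrative:

    import torch

    T, L, D = 12, 3, 2           # T must be a multiple of L
    A = torch.rand(T)            # per-step decay
    X = torch.randn(T, D)        # per-step input
    Y0 = torch.zeros(L, D)       # initial state of the L interleaved chains

    # Naive period-L recurrence: Y[t] = A[t] * Y[t - L] + X[t]
    Y = torch.empty(T, D)
    for t in range(T):
        prev = Y0[t] if t < L else Y[t - L]
        Y[t] = A[t] * prev + X[t]

    # Same values after unflattening time into (T // L, L): each of the
    # L columns is a first-order recurrence along the first axis
    A2, X2 = A.unflatten(0, (-1, L)), X.unflatten(0, (-1, L))
    Z, prev = torch.empty_like(X2), Y0
    for s in range(T // L):
        prev = A2[s][:, None] * prev + X2[s]
        Z[s] = prev

    assert torch.allclose(Y, Z.flatten(0, 1))
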
diff --git a/tasks.py b/tasks.py
index 4777a11..727b196 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1473,6 +1473,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1480,7 +1482,9 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
 
         if logger is not None:
             logger(
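
End to end, the two new knobs reach GridFactory through tasks.Grid; a minimal sketch with illustrative values (argument and flag names are from this commit):

    import grid

    factory = grid.GridFactory(size=6, nb_shapes=12, nb_colors=20)

    # equivalently, through the new command-line flags:
    #   ./main.py --task grid --grid_size 6 --grid_nb_shapes 12 --grid_nb_colors 20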