Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2023 06:29:10 +0000 (08:29 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2023 06:29:10 +0000 (08:29 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 7cb8d4f..db982ca 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -173,15 +173,27 @@ for n in vars(args):
 ######################################################################
 
 
+# ra_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in tqdm.tqdm(
-        zip(input.split(batch_size), ar_mask.split(batch_size)),
-        dynamic_ncols=True,
-        desc="autoregression",
-        total=input.size(0) // batch_size,
-    ):
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
@@ -317,6 +329,7 @@ class TaskPicoCLVR(Task):
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
@@ -975,9 +988,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
index b4446c6..6a12a5a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -5,6 +5,11 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
 import math
 
 import torch
@@ -14,19 +19,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -78,6 +70,19 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+##############################
+
+
 class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()