Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 13:42:19 +0000 (14:42 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 13:42:19 +0000 (14:42 +0100)
main.py
mygpt.py

diff --git a/main.py b/main.py
index fabebdd..18c0730 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -478,7 +478,7 @@ def get_lr(n_epoch, it):
 
         if it < args.nb_warmup_iter:
             return args.legacy_large_lr * it / args.nb_warmup_iter
-        elif it < args.legacy_nb_epoch_large_lr:
+        elif n_epoch < args.legacy_nb_epoch_large_lr:
             return args.legacy_large_lr
         else:
             return args.legacy_small_lr
index d1acf22..33c6fee 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -530,11 +530,11 @@ class Caterpillar(nn.Module):
         ######################################################################
         # Compute the recurrent state
 
-        # This is the Gating sequence that modulates if they key and
-        # values should be stored in one of the CH pairs of the
-        # current stack. The CH gating values are independent, which
-        # means that the same thing could be stored up to CH times or
-        # not at all
+        # This is the Gating sequence that modulates the storing of
+        # the new key and value in the CH pairs of the current
+        # stack. The CH gating values are independent, which means
+        # that the current K/V could be stored in all the pairs of the
+        # recurrent state, or not at all.
 
         G = (
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
@@ -552,10 +552,11 @@ class Caterpillar(nn.Module):
         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
 
-        # Here there is a trick: The parallel scan operates with a
-        # period of L, so we split the sequence indexing in two axes,
-        # the second of size CL, and run the parallel scan using the
-        # other alone as the sequence index.
+        # Here there is a trick: Since the stack at time t is computed
+        # by updating that at time t-L, the parallel scan operates
+        # with a period of L. To do so we split the time indexing in
+        # two axes, the second of size CL, and run the parallel scan
+        # using the other alone as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))