Update.
author François Fleuret <francois@fleuret.org>
Thu, 11 Jan 2024 20:49:52 +0000 (21:49 +0100)
committer François Fleuret <francois@fleuret.org>
Thu, 11 Jan 2024 20:49:52 +0000 (21:49 +0100)
mygpt.py

index 9d3abb6..633ad64 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -569,17 +569,20 @@ class Caterpillar(nn.Module):
         # Roll the gating indexes
 
         warnings.warn("rotating barrel", RuntimeWarning)
+
+        # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}")
+
         n_barrel = torch.arange(N, device=G.device)[:, None, None, None]
         h_barrel = torch.arange(H, device=G.device)[None, :, None, None]
         r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
         t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
-        r_barrel = (r_barrel + t_barrel + t0) % R
-
-        # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}")
+        r_barrel = (r_barrel + (t_barrel + t0) // L) % R
 
+        # GG = G.gather(dim=2,index=r_barrel)
         G = G[n_barrel, h_barrel, r_barrel, t_barrel]
 
-        # print(G.sum())
+        # print("SANITY", (GG-G).abs())
+        # exit(0)
 
         ######################################################################
         # The "flashbacks"