Update.
[mygptrnn.git] / mygpt.py
index c061eb4..0e94672 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -514,7 +514,7 @@ class Caterpillar(nn.Module):
         T = bs.x.size(1)
         DV = self.w_V.size(1)
         DK = self.w_K.size(1)
-        Dout = self.w_O.size(1)
+        DM = self.w_O.size(1)
         CH = self.caterpillar_height
         CL = self.caterpillar_length
 
@@ -522,6 +522,8 @@ class Caterpillar(nn.Module):
             t0 >= CL and (t1 - t0) % CL == 0
         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
 
+        # We cache values to deal efficiently with auto-regression
+
         if bs.init_cache:
             self.rec_V = X.new_zeros(N, CH, T, DV)
             self.rec_K = X.new_zeros(N, CH, T, DK)
@@ -530,7 +532,7 @@ class Caterpillar(nn.Module):
             self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
             self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
 
-            self.cache_Y = X.new_zeros(N, T, Dout)
+            self.cache_Y = X.new_zeros(N, T, DM)
 
         ######################################################################
         # Compute the recurrent state
@@ -563,7 +565,7 @@ class Caterpillar(nn.Module):
         # by updating that at time t-L, the parallel scan operates
         # with a period of L. To do so we split the time indexing in
         # two axes, the second of size CL, and run the parallel scan
-        # using the other alone as the sequence index.
+        # using the other as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))