Update: compute the V/K projections before the recurrent state; move the gate-dropout stub after the cached initial state.
author     François Fleuret <francois@fleuret.org>
           Wed, 10 Jan 2024 07:43:27 +0000 (08:43 +0100)
committer  François Fleuret <francois@fleuret.org>
           Wed, 10 Jan 2024 07:43:27 +0000 (08:43 +0100)
mygpt.py

index 17f2f6d..ed4b2a7 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -540,6 +540,9 @@ class Caterpillar(nn.Module):
 
             self.cache_Y = X.new_zeros(N, T, DM)
 
+        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
         ######################################################################
         # Compute the recurrent state
 
@@ -558,24 +561,21 @@ class Caterpillar(nn.Module):
 
         G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        if self.training and self.proba_gate_dropout > 0.0:
-            warnings.warn("gate dropout", RuntimeWarning)
-            epsilon = 0.5
-
-        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
-        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
-
         # We prepare the arguments for the parallel scan
 
         A = 1 - G.sum(1)
         gated_V = torch.einsum("nhet,nhtd->netd", G, V)
         gated_K = torch.einsum("nhet,nhtd->netd", G, K)
 
-        # Initial recurrent state
+        # We start from the cached values, which matters during inference
 
         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
 
+        if self.training and self.proba_gate_dropout > 0.0:
+            warnings.warn("gate dropout", RuntimeWarning)
+            epsilon = 0.5
+
         #################################################################
         # Associative scan
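For readers checking the tensor algebra, here is a standalone shape sketch of the two pairs of einsum contractions this patch touches. The dimension letters follow the einsum subscripts in the code above (n: batch, t: time, c: channels, h: heads, d: head dimension, e: recurrent rows); the concrete sizes, and the name E for the "e" axis, are made up for illustration and are not taken from the repository.

    import torch

    # Illustrative sizes only; the letters mirror the einsum subscripts
    # in the patch above.
    N, T, C = 2, 8, 16   # batch, time, input channels
    H, D = 4, 5          # heads, key/value dimension
    E = 3                # recurrent rows (the "e" axis of G)

    X = torch.randn(N, T, C)
    w_V = torch.randn(H, D, C)
    w_K = torch.randn(H, D, C)

    # The two projections hoisted before the recurrent-state computation:
    # (n,t,c) x (h,d,c) -> (n,h,t,d), contracting over the channels c.
    V = torch.einsum("ntc,hdc->nhtd", X, w_V)
    K = torch.einsum("ntc,hdc->nhtd", X, w_K)
    assert V.shape == (N, H, T, D)

    # Gates, normalized over the head axis so that for each (e, t) the
    # total gate mass is at most one.
    G = torch.rand(N, H, E, T)
    G = G / G.sum(1, keepdim=True).clamp(min=1)

    # Arguments of the parallel scan: A is the mass left after gating,
    # and gated_V / gated_K sum the heads into the recurrent rows:
    # (n,h,e,t) x (n,h,t,d) -> (n,e,t,d), contracting over the heads h.
    A = 1 - G.sum(1)
    gated_V = torch.einsum("nhet,nhtd->netd", G, V)
    gated_K = torch.einsum("nhet,nhtd->netd", G, K)
    assert A.shape == (N, E, T) and gated_V.shape == (N, E, T, D)

Since V and K depend only on X, hoisting them above the recurrent-state block changes no values; it only groups the projections with the rest of the per-timestep computation, ahead of the scan that consumes gated_V and gated_K.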