Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jan 2024 13:11:45 +0000 (14:11 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jan 2024 13:11:45 +0000 (14:11 +0100)
mygpt.py

index 0414bb6..760a3c6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -629,15 +629,21 @@ class Caterpillar(nn.Module):
 
             warnings.warn("gate dropout", RuntimeWarning)
 
+            if self.gate_dropout_sync:
+                shape_kill = (N, 1, 1)
+            else:
+                shape_kill = (N, H, R)
+
             # Pick a point in each of the NxHxR timeline and set this
             # entry and the following to 1
             kill = (
-                torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+                == 0
             ).cumsum(dim=3)
 
             # Keep these mask for only some of the NxHxR
             kill = kill * (
-                torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
+                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
             )
 
             # The coefficient to keep are the complementary