Update
[beaver.git] / beaver.py
index 4f41832..5ee468e 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -524,7 +524,10 @@ amm_generator = None
 if args.noncausal_prompt:
     amm_generator = lambda d: torch.logical_and(
         torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
-        torch.arange(d)[None, None, :, None] >= d // 2,
+        torch.logical_or(
+            torch.arange(d)[None, None, :, None] >= d // 2,
+            torch.arange(d)[None, None, None, :] >= d // 2,
+        ),
     )
 
 model = mygpt.MyGPT(