if args.noncausal_prompt:
amm_generator = lambda d: torch.logical_and(
torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
- torch.arange(d)[None, None, :, None] >= d // 2,
+ torch.logical_or(
+ torch.arange(d)[None, None, :, None] >= d // 2,
+ torch.arange(d)[None, None, None, :] >= d // 2,
+ ),
)
model = mygpt.MyGPT(