X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=4f41832fcd928b7f25bc9722a66b2d810229711e;hb=eb86c22964f03f186ee225f129bc260128b10f9a;hp=6a6343d4fc3e303f3352bfaaf7a29a5eabf020fc;hpb=88bcf05864ddf89d071ee4be17af57b3b3ce7c2a;p=beaver.git diff --git a/beaver.py b/beaver.py index 6a6343d..4f41832 100755 --- a/beaver.py +++ b/beaver.py @@ -66,6 +66,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--random_regression_order", action="store_true", default=False) +parser.add_argument("--noncausal_prompt", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -517,6 +519,14 @@ log_string(f"vocabulary_size {vocabulary_size}") ############################## +amm_generator = None + +if args.noncausal_prompt: + amm_generator = lambda d: torch.logical_and( + torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :], + torch.arange(d)[None, None, :, None] >= d // 2, + ) + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -526,6 +536,7 @@ model = mygpt.MyGPT( nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, + amm_generator=amm_generator, ) model.to(device)