projects
/
beaver.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update
[beaver.git]
/
beaver.py
diff --git
a/beaver.py
b/beaver.py
index
6a6343d
..
4f41832
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-66,6
+66,8
@@
parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
parser.add_argument("--random_regression_order", action="store_true", default=False)
parser.add_argument("--random_regression_order", action="store_true", default=False)
+parser.add_argument("--noncausal_prompt", action="store_true", default=False)
+
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--overwrite_results", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--overwrite_results", action="store_true", default=False)
@@
-517,6
+519,14
@@
log_string(f"vocabulary_size {vocabulary_size}")
##############################
##############################
+amm_generator = None
+
+if args.noncausal_prompt:
+ amm_generator = lambda d: torch.logical_and(
+ torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
+ torch.arange(d)[None, None, :, None] >= d // 2,
+ )
+
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
@@
-526,6
+536,7
@@
model = mygpt.MyGPT(
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
+ amm_generator=amm_generator,
)
model.to(device)
)
model.to(device)