projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
7b2d37d
)
Update.
author
François Fleuret
<francois@fleuret.org>
Thu, 8 Feb 2024 12:28:27 +0000
(13:28 +0100)
committer
François Fleuret
<francois@fleuret.org>
Thu, 8 Feb 2024 12:28:27 +0000
(13:28 +0100)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
91c885b
..
d6845e8
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-87,7
+87,9
@@
parser.add_argument("--model", type=str, default=None)
parser.add_argument("--attention", type=str, default=None)
parser.add_argument("--attention", type=str, default=None)
-parser.add_argument("--proba_memex", type=float, default=0)
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=1)
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--dim_model", type=int, default=None)
@@
-738,7
+740,7
@@
log_string(f"device {device}")
vocabulary_size = task.vocabulary_size()
vocabulary_size = task.vocabulary_size()
-if args.
proba_memex
> 0:
+if args.
memex_proba
> 0:
vocabulary_size += 1
log_string(f"vocabulary_size {vocabulary_size}")
vocabulary_size += 1
log_string(f"vocabulary_size {vocabulary_size}")
@@
-902,13
+904,11
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- def add_memex(batches,
proba_memex
):
+ def add_memex(batches,
memex_proba
):
for input in batches:
for input in batches:
- if torch.rand(1).item() < proba_memex:
- sep = (
- torch.full(
- (input.size(0), 1), vocabulary_size - 1, device=input.device
- ),
+ if torch.rand(1).item() < memex_proba:
+ sep = torch.full(
+ (input.size(0), 1), vocabulary_size - 1, device=input.device
)
yield torch.cat(
)
yield torch.cat(
@@
-921,7
+921,10
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
yield input
)
yield input
- train_batches = add_memex(task.batches(split="train"), args.proba_memex)
+ train_batches = add_memex(
+ task.batches(split="train"),
+ args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+ )
for input in train_batches:
model.reset_inner_loss()
for input in train_batches:
model.reset_inner_loss()