Added args.learning_rate_end for an exponential decay.

Author:    Francois Fleuret <francois@fleuret.org>
Date:      Mon, 8 Aug 2022 15:59:08 +0000 (17:59 +0200)
Committer: Francois Fleuret <francois@fleuret.org>
Parent:    3b62d29
diff --git a/main.py b/main.py
index d4a8cfb..f6934b7 100755
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
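The second hunk below wires the new flag into the training loop, with a negative value acting as an "off" switch. For orientation, a hedged sketch of the resulting command lines; these invocations are illustrative, and only the flag names visible in this patch are confirmed:

    # Hypothetical invocations (flag names taken from this patch):
    #   python main.py --learning_rate 1e-3 --learning_rate_end 1e-6   # decay from 1e-3 down to 1e-6
    #   python main.py --learning_rate 1e-3 --learning_rate_end -1     # negative value: constant lr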
@@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy)
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
 
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
 
     if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
     elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
     elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
     else:
         raise ValueError(f'Unknown optimizer {args.optim}.')
 
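Outside the patch itself, a minimal sketch of the schedule this commit implements: interpolating log(lr) linearly in the epoch index gives a geometric decay from args.learning_rate to args.learning_rate_end. The nb_epochs value below is an assumption for illustration; the formula needs nb_epochs > 1 to avoid a division by zero.

    import math

    learning_rate     = 1e-3   # new default from the first hunk
    learning_rate_end = 1e-6   # new default from the first hunk
    nb_epochs         = 10     # assumed for illustration

    for n_epoch in range(nb_epochs):
        u = n_epoch / (nb_epochs - 1)   # 0 at the first epoch, 1 at the last
        lr = math.exp((1 - u) * math.log(learning_rate) +
                      u * math.log(learning_rate_end))
        print(f'epoch {n_epoch} learning_rate {lr:.3e}')

Equivalently, lr = learning_rate * (learning_rate_end / learning_rate) ** u, so each epoch multiplies the rate by the constant factor (1e-6 / 1e-3) ** (1/9), roughly 0.46 in this example.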