X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=f6934b78cd497f1f6f1ed47c1460ba20005fb333;hb=13a6ecc6e00a75ce5a95c54c11ce6f60902f57f1;hp=d4a8cfb06c846e60f7f950d3cc8dadf2b9a7ce1b;hpb=3b62d298013c7b940aec7cab0f74fb5118493f99;p=mygpt.git diff --git a/main.py b/main.py index d4a8cfb..f6934b7 100755 --- a/main.py +++ b/main.py @@ -42,7 +42,10 @@ parser.add_argument('--optim', type = str, default = 'adam') parser.add_argument('--learning_rate', - type = float, default = 1e-4) + type = float, default = 1e-3) + +parser.add_argument('--learning_rate_end', + type = float, default = 1e-6) parser.add_argument('--dim_model', type = int, default = 512) @@ -465,12 +468,20 @@ train_set_perplexity = math.exp(entropy) for n_epoch in range(nb_epochs_finished, nb_epochs): + if args.learning_rate_end < 0: + lr = args.learning_rate + else: + u = n_epoch / (nb_epochs - 1) + lr = math.exp((1 - u) * math.log(args.learning_rate) + + u * math.log(args.learning_rate_end)) + log_string(f'learning_rate {lr}') + if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) + optimizer = torch.optim.SGD(model.parameters(), lr = lr) elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) + optimizer = torch.optim.Adam(model.parameters(), lr = lr) elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) + optimizer = torch.optim.AdamW(model.parameters(), lr = lr) else: raise ValueError(f'Unknown optimizer {args.optim}.')