mygpt.git / commitdiff (parent: c3621f9)
Added the rng state in the checkpoint.
author     Francois Fleuret <francois@fleuret.org>
           Sun, 7 Aug 2022 19:50:36 +0000 (21:50 +0200)
committer  Francois Fleuret <francois@fleuret.org>
           Sun, 7 Aug 2022 19:50:36 +0000 (21:50 +0200)
main.py
diff --git a/main.py b/main.py
index b01ea0a..d4a8cfb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -430,17 +430,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
@@ -448,10 +437,12 @@ if args.no_checkpoint:
 
 else:
     try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
@@ -472,7 +463,16 @@ token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
 
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -505,16 +505,19 @@ for k in range(nb_epochs_finished, nb_epochs):
     train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
     test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-    log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+    log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-    task.produce_results(k, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################
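
For context, a minimal self-contained sketch (not the repository's code) of the checkpointing pattern this commit moves to: the checkpoint now carries the torch RNG state, and the CUDA RNG state when a GPU is present, instead of the optimizer state, which is no longer saved since the optimizer is recreated at the start of each epoch. The file name and the model below are placeholders.

import torch

checkpoint_name = 'checkpoint.pth'      # placeholder path
model = torch.nn.Linear(8, 8)           # stand-in for the real model

# Saving at the end of an epoch: store the epoch counter, the model weights,
# and the CPU (plus CUDA, if available) RNG states.
checkpoint = {
    'nb_epochs_finished': 1,
    'model_state': model.state_dict(),
    'rng_state': torch.get_rng_state(),
}
if torch.cuda.is_available():
    checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
torch.save(checkpoint, checkpoint_name)

# Restoring before resuming training: reload the weights and put the RNGs
# back where they were, so random operations such as data shuffling and
# dropout continue as if training had never been interrupted.
checkpoint = torch.load(checkpoint_name)
model.load_state_dict(checkpoint['model_state'])
torch.set_rng_state(checkpoint['rng_state'])
if torch.cuda.is_available() and 'cuda_rng_state' in checkpoint:
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])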