parser.add_argument('--optim',
                    type = str, default = 'adam')
parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
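+# end point of the learning-rate schedule; a negative value disables the
+# schedule and keeps the rate constant at --learning_rate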
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
parser.add_argument('--dim_model',
                    type = int, default = 512)
parser.add_argument('--dropout',
                    type = float, default = 0.1)
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
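+# when set, synthesis takes the most likely token at each step instead of sampling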
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)
for s in range(first, input.size(1)):
    output = model(input)
    logits = output[:, s]
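+    # pick the next token either greedily (argmax) or by sampling the
+    # categorical distribution defined by the logits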
-    if args.synthesis_sampling:
+    if args.deterministic_synthesis:
+        t_next = logits.argmax(1)
+    else:
        dist = torch.distributions.categorical.Categorical(logits = logits)
        t_next = dist.sample()
-    else:
-        t_next = logits.argmax(1)
    input[:, s] = t_next
return results
self.device = device
nb = args.data_size if args.data_size > 0 else 250000
+ log_string(f'generating {nb} samples (can take some time)')
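+# 4/5 of the samples go to the train set, 1/5 to the test set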
self.train_descr = generate_descr((nb * 4) // 5)
self.test_descr = generate_descr((nb * 1) // 5)
input = F.pad(input, (0, 1)) # Add the next token, the one to predict
output = model(input)
logits = output[0, -1]
-if args.synthesis_sampling:
+if args.deterministic_synthesis:
+    t_next = logits.argmax()
+else:
    dist = torch.distributions.categorical.Categorical(logits = logits)
    t_next = dist.sample()
-else:
-    t_next = logits.argmax()
t_generated.append(self.vocab.lookup_token(t_next))
if t_generated[-1] == '<nul>': break
######################################################################
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
nb_epochs_finished = 0
if args.no_checkpoint:
    log_string('not trying to load checkpoint.')
else:
    try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
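+        # restore the RNG state so that a resumed run continues the same random sequence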
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
    except FileNotFoundError:
        log_string('starting from scratch.')
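+# the train-set perplexity is the exponential of the entropy (in nats) of the
+# empirical token distribution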
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
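+    # log-linear interpolation of the learning rate between --learning_rate and
+    # --learning_rate_end over the whole run (assumes nb_epochs > 1)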
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
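+    # the optimizer is re-created every epoch so the scheduled lr takes effect;
+    # note that this also resets the Adam/AdamW moment estimates each epoch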
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
    model.train()
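+    # cap the average loss at 100 nats before exponentiating, to keep math.exp
+    # from overflowing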
    train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
    test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
-    log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+    log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
-    task.produce_results(k, model)
+    task.produce_results(n_epoch, model)
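+    # save enough state (weights + RNG) to resume training reproducibly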
    checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
        'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
    }
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
    torch.save(checkpoint, args.checkpoint_name)
######################################################################