def generation_order(x, fixed_len):
if args.random_regression_order:
order = torch.rand(x.size(), device=x.device)
- order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+ order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=x.device)
order = order.sort(1).indices
else:
order = (
def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
- for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+ for input, ar_mask, order in zip(
+ input.split(batch_size), ar_mask.split(batch_size), order.split(batch_size)
+ ):
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
# Needed to initialize the model's cache
######################################################################
-def compute_perplexity(model, fixed_len, split="train"):
+def compute_perplexity(model, task, fixed_len, split="train"):
with torch.autograd.no_grad():
t = model.training
model.eval()
scores = scores.reshape(-1, task.height, task.width)
mazes = mazes.reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
+ filename = (
+ f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
+ )
maze.save_image(
- os.path.join(
- args.result_dir,
- f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
- ),
+ os.path.join(args.result_dir, filename),
mazes=mazes,
score_paths=scores,
score_truth=targets,
)
+ log_string(f"wrote {filename}")
+
# -------------------
gpt.train(t)
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
result *= 1 - ar_mask
- masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+ x, order = shuffle(result, self.height * self.width)
+ masked_inplace_autoregression(
+ model, self.batch_size, x, ar_mask, order=order
+ )
+ result = reorder(x, order, back=True)
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
+ filename = f"result_{n_epoch:04d}.png"
maze.save_image(
- os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+ os.path.join(args.result_dir, filename),
mazes=mazes,
target_paths=paths,
predicted_paths=predicted_paths,
path_correct=maze.path_correctness(mazes, predicted_paths),
)
+ log_string(f"wrote {filename}")
model.train(t)
if nb_epochs_finished >= args.nb_epochs:
n_epoch = nb_epochs_finished
train_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="train"
+ model, task, fixed_len=task.height * task.width, split="train"
)
test_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="test"
+ model, task, fixed_len=task.height * task.width, split="test"
)
log_string(
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
test_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="test"
+ model, task, fixed_len=task.height * task.width, split="test"
)
log_string(