######################################################################
-parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
parser.add_argument("--log_filename", type=str, default="train.log")
for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
- model(
- mygpt.BracketedSequence(input, 0, i.min())
- ) # Needed to initialize the model's cache
+ # Needed to initialize the model's cache
+ model(mygpt.BracketedSequence(input, 0, i.min()))
for s in range(i.min(), i.max() + 1):
output = model(mygpt.BracketedSequence(input, s, 1)).x
logits = output[:, s]
)
mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
self.train_input = self.map2seq(mazes_train, paths_train)
- self.nb_codes = self.train_input.max() + 1
mazes_test, paths_test = maze.create_maze_data(
nb_test_samples,
mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
self.test_input = self.map2seq(mazes_test, paths_test)
+ self.nb_codes = self.train_input.max() + 1
+
def batches(self, split="train", nb_to_use=-1):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
- result *= 1-ar_mask
+ result *= 1 - ar_mask
masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
mazes, paths = self.seq2map(result)
nb_correct += maze.path_correctness(mazes, paths).long().sum()
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
- result *= 1-ar_mask
+ result *= 1 - ar_mask
masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
maze.save_image(
- f"result_{n_epoch:04d}.png",
+ os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
mazes,
paths,
predicted_paths,
for input in task.batches(split="test"):
input = input.to(device)
- # input, loss_masks, true_images = task.excise_last_image(input)
- # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)