Update
[beaver.git] / beaver.py
index 4d4f98d..517f29a 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -26,9 +26,7 @@ else:
 
 ######################################################################
 
-parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
-)
+parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")
 
 parser.add_argument("--log_filename", type=str, default="train.log")
 
@@ -38,9 +36,11 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
-parser.add_argument("--batch_size", type=int, default=100)
+parser.add_argument("--nb_train_samples", type=int, default=200000)
+
+parser.add_argument("--nb_test_samples", type=int, default=50000)
 
-parser.add_argument("--data_size", type=int, default=-1)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--optim", type=str, default="adam")
 
@@ -73,11 +73,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # maze options
 
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
 
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
 
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
 
 ######################################################################
 
@@ -129,9 +129,8 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask):
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
-            model(
-                mygpt.BracketedSequence(input, 0, i.min())
-            )  # Needed to initialize the model's cache
+            # Needed to initialize the model's cache
+            model(mygpt.BracketedSequence(input, 0, i.min()))
         for s in range(i.min(), i.max() + 1):
             output = model(mygpt.BracketedSequence(input, s, 1)).x
             logits = output[:, s]
@@ -170,16 +169,23 @@ class TaskMaze(Task):
         s = s.reshape(s.size(0), -1, self.height, self.width)
         return (s[:, k] for k in range(s.size(1)))
 
-    def __init__(self, batch_size, height, width, nb_walls, device=torch.device("cpu")):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_walls,
+        device=torch.device("cpu"),
+    ):
         self.batch_size = batch_size
         self.height = height
         self.width = width
         self.device = device
 
-        nb = args.data_size if args.data_size > 0 else 250000
-
         mazes_train, paths_train = maze.create_maze_data(
-            (4 * nb) // 5,
+            nb_train_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
@@ -187,10 +193,9 @@ class TaskMaze(Task):
         )
         mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
         self.train_input = self.map2seq(mazes_train, paths_train)
-        self.nb_codes = self.train_input.max() + 1
 
         mazes_test, paths_test = maze.create_maze_data(
-            nb // 5,
+            nb_test_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
@@ -199,9 +204,13 @@ class TaskMaze(Task):
         mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
         self.test_input = self.map2seq(mazes_test, paths_test)
 
-    def batches(self, split="train"):
+        self.nb_codes = self.train_input.max() + 1
+
+    def batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
         for batch in tqdm.tqdm(
             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
         ):
@@ -210,12 +219,13 @@ class TaskMaze(Task):
     def vocabulary_size(self):
         return self.nb_codes
 
-    def compute_error(self, model, split="train"):
+    def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
-        for input in task.batches(split):
+        for input in task.batches(split, nb_to_use):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
@@ -224,26 +234,42 @@ class TaskMaze(Task):
         return nb_total, nb_correct
 
     def produce_results(self, n_epoch, model):
-        train_nb_total, train_nb_correct = self.compute_error(model, "train")
-        log_string(
-            f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct = self.compute_error(model, "test")
-        log_string(
-            f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        input = self.test_input[:32]
-        result = input.clone()
-        ar_mask = result.new_zeros(result.size())
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            train_nb_total, train_nb_correct = self.compute_error(
+                model, "train", nb_to_use=1000
+            )
+            log_string(
+                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            )
+
+            test_nb_total, test_nb_correct = self.compute_error(
+                model, "test", nb_to_use=1000
+            )
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            input = self.test_input[:32]
+            result = input.clone()
+            ar_mask = result.new_zeros(result.size())
+            ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
+            masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
 
-        ar_mask[:, self.height * self.width :] = 1
-        masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+            mazes, paths = self.seq2map(input)
+            _, predicted_paths = self.seq2map(result)
+            maze.save_image(
+                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+                mazes,
+                paths,
+                predicted_paths,
+                maze.path_correctness(mazes, predicted_paths),
+            )
 
-        mazes, paths = self.seq2map(input)
-        _, predicted_paths = self.seq2map(result)
-        maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+            model.train(t)
 
 
 ######################################################################
@@ -252,10 +278,12 @@ log_string(f"device {device}")
 
 
 task = TaskMaze(
+    nb_train_samples=args.nb_train_samples,
+    nb_test_samples=args.nb_test_samples,
     batch_size=args.batch_size,
-    height=args.world_height,
-    width=args.world_width,
-    nb_walls=args.world_nb_walls,
+    height=args.maze_height,
+    width=args.maze_width,
+    nb_walls=args.maze_nb_walls,
     device=device,
 )
 
@@ -390,9 +418,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)