Update master
authorFrançois Fleuret <francois@fleuret.org>
Fri, 7 Apr 2023 21:35:12 +0000 (23:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 7 Apr 2023 21:35:12 +0000 (23:35 +0200)
beaver.py

index e69f151..5abe39b 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -238,9 +238,9 @@ def oneshot_trace_loss(mazes, output, policies, height, width):
 def oneshot(model, learning_rate_scheduler, task):
     t = model.training
     model.eval()
-    mazes = task.test_input[:32].clone()
+    mazes = task.test_input[:48].clone()
     mazes[:, task.height * task.width :] = 0
-    policies = task.test_policies[:32]
+    policies = task.test_policies[:48]
     targets = maze.stationary_densities(
         mazes[:, : task.height * task.width].view(-1, task.height, task.width),
         policies.view(-1, 4, task.height, task.width),
@@ -253,7 +253,7 @@ def oneshot(model, learning_rate_scheduler, task):
     )
     mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
     targets = targets.reshape(-1, task.height, task.width)
-    paths = task.test_input[:32, task.height * task.width :].reshape(
+    paths = task.test_input[:48, task.height * task.width :].reshape(
         -1, task.height, task.width
     )
     filename = f"oneshot.png"
@@ -335,8 +335,8 @@ def oneshot_old(gpt, learning_rate_scheduler, task):
         )
 
         # -------------------
-        mazes = task.test_input[:32, : task.height * task.width]
-        policies = task.test_policies[:32]
+        mazes = task.test_input[:48, : task.height * task.width]
+        policies = task.test_policies[:48]
         output_gpt = eval_mygpt(
             gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
         )
@@ -579,7 +579,7 @@ class TaskMaze(Task):
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
-            input = self.test_input[:32]
+            input = self.test_input[:48]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1