Update.
[picoclvr.git] / tasks.py
index 0ab1823..cbc8e6b 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1459,8 +1459,7 @@ class Grid(Task):
         nb_train_samples,
         nb_test_samples,
         batch_size,
-        height,
-        width,
+        size,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1468,7 +1467,7 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(height=height, width=width)
+        self.grid_factory = grid.GridFactory(size=size)
 
         if logger is not None:
             logger(
@@ -1496,8 +1495,8 @@ class Grid(Task):
         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
         self.t_nul = self.token2id["#"]
-        self.t_true = self.token2id["<true>"]
-        self.t_false = self.token2id["<false>"]
+        self.t_true = self.token2id["true"]
+        self.t_false = self.token2id["false"]
 
         # Tokenize the train and test sets
         self.train_input = self.tensorize(self.train_descr)
@@ -1540,8 +1539,8 @@ class Grid(Task):
         nb_total = ar_mask.sum().item()
         nb_correct = ((correct == result).long() * ar_mask).sum().item()
 
-        logger(f"test_performance {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {nb_correct / nb_total}")
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
 
 
 ######################################################################