Update.
[mygptrnn.git] / tasks.py
index 218ff36..57c6801 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -106,7 +106,7 @@ class SandBox(Task):
             device
         ), self.test_ar_mask.to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
         # A bit of paranoia never hurts
         assert self.nb_codes <= max_nb_codes
@@ -579,7 +579,7 @@ class Maze(Task):
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -756,7 +756,7 @@ class Snake(Task):
             self.device,
         )
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -871,7 +871,7 @@ class Stack(Task):
         counts = F.one_hot(counts).sum(0)
         logger(f"test_pop_stack_counts {counts}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1078,7 +1078,7 @@ class RPL(Task):
                 s = " ".join(seq)
                 logger(f"example_seq {s}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1308,7 +1308,7 @@ class Expr(Task):
         self.train_input = self.tensorize(train_sequences)
         self.test_input = self.tensorize(test_sequences)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1639,7 +1639,7 @@ class QMLP(Task):
             for e in self.test_ref_test_errors:
                 f.write(f"{e}\n")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", desc=None):
         assert split in {"train", "test"}