Update.
author: François Fleuret <francois@fleuret.org>
Wed, 19 Jul 2023 15:51:03 +0000 (17:51 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 19 Jul 2023 15:51:03 +0000 (17:51 +0200)
tasks.py

index 0f44760..0a4dd6f 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1042,7 +1042,7 @@ class RPL(Task):
                 )
             ],
             0,
-        ).to(self.device)
+        )
 
     def seq2str(self, seq):
         return " ".join([self.id2token[i] for i in seq])
@@ -1101,7 +1101,7 @@ class RPL(Task):
         self.test_input = self.tensorize(test_sequences)
 
         if logger is not None:
-            for x in self.train_input[:10]:
+            for x in self.train_input[:25]:
                 end = (x != self.t_nul).nonzero().max().item() + 1
                 seq = [self.id2token[i.item()] for i in x[:end]]
                 s = " ".join(seq)
@@ -1120,7 +1120,7 @@ class RPL(Task):
             input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
-            batch = batch[:, :last]
+            batch = batch[:, :last].to(self.device)
             yield batch
 
     def vocabulary_size(self):
@@ -1129,6 +1129,7 @@ class RPL(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
+        # --------------------------------------------------------------------
         def compute_nb_errors(input, nb_to_log=0):
             result = input.clone()
             s = (result == self.t_prog).long()
@@ -1169,8 +1170,10 @@ class RPL(Task):
 
             return sum_nb_total, sum_nb_errors
 
+        # --------------------------------------------------------------------
+
         test_nb_total, test_nb_errors = compute_nb_errors(
-            self.test_input[:1000], nb_to_log=10
+            self.test_input[:1000].to(self.device), nb_to_log=10
         )
 
         logger(