Update.
[picoclvr.git] / tasks.py
index 1ea3b5d..e5d3a7e 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -117,7 +117,8 @@ class TaskFromFile(Task):
 
     def __init__(
         self,
-        filename,
+        train_filename,
+        test_filename,
         nb_train_samples,
         nb_test_samples,
         batch_size,
@@ -126,26 +127,37 @@ class TaskFromFile(Task):
         self.batch_size = batch_size
         self.device = device
 
-        pairs = []
-        with open(filename, "r") as f:
-            for _ in range(nb_train_samples + nb_test_samples):
-                sequence = f.readline().strip()
-                pred_mask = f.readline().strip()
-                assert len(sequence) == len(pred_mask)
-                assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
-                pairs.append((sequence, pred_mask))
-
-        symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+        def read_file(filename, nb=-1):
+            pairs = []
+            with open(filename, "r") as f:
+                while True:
+                    sequence = f.readline().strip()
+                    if not sequence:
+                        break
+                    pred_mask = f.readline().strip()
+                    assert len(sequence) == len(pred_mask)
+                    assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
+                    pairs.append((sequence, pred_mask))
+                    if len(pairs) == nb:
+                        break
+
+            if nb > 0:
+                pairs = pairs[:nb]
+                assert len(pairs) == nb
+
+            return pairs
+
+        train_pairs = read_file(train_filename, nb_train_samples)
+        test_pairs = read_file(test_filename, nb_test_samples)
+
+        symbols = ["#"] + list(
+            set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
+        )
         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
 
-        self.train_input, self.train_pred_masks = self.tensorize(
-            pairs[:nb_train_samples]
-        )
-        self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
-
-        assert self.train_input.size(0) == nb_train_samples
-        assert self.test_input.size(0) == nb_test_samples
+        self.train_input, self.train_pred_masks = self.tensorize(train_pairs)
+        self.test_input, self.test_pred_masks = self.tensorize(test_pairs)
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -176,7 +188,7 @@ class TaskFromFile(Task):
 
         logger(f"----------------------------------------------------------")
 
-        for e in self.tensor2str(result[:10]):
+        for e in self.tensor2str(result[:50]):
             logger(f"test_before {e}")
 
         masked_inplace_autoregression(
@@ -190,7 +202,7 @@ class TaskFromFile(Task):
 
         logger(f"----------------------------------------------------------")
 
-        for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
+        for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
             logger(f"test_after  {e}")
             logger(f"correct     {c}")