Update.
authorFrancois Fleuret <francois@fleuret.org>
Mon, 25 Jul 2022 16:15:00 +0000 (18:15 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 25 Jul 2022 16:15:00 +0000 (18:15 +0200)
main.py

diff --git a/main.py b/main.py
index b579177..77c4b9e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -130,36 +130,39 @@ class TaskPicoCLVR(Task):
                  height, width, many_colors = False,
                  device = torch.device('cpu')):
 
+        def generate_descr(nb):
+            descr = picoclvr.generate(
+                nb,
+                height = self.height, width = self.width,
+                many_colors = many_colors
+            )
+
+            descr = [ s.strip().split(' ') for s in descr ]
+            l = max([ len(s) for s in descr ])
+            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+
+            return descr
+
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
-        descr = picoclvr.generate(
-            nb,
-            height = self.height, width = self.width,
-            many_colors = many_colors
-        )
-
-        # self.test_descr = descr[:nb // 5]
-        # self.train_descr = descr[nb // 5:]
-
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+        self.train_descr = generate_descr((nb * 4) // 5)
+        self.test_descr = generate_descr((nb * 1) // 5)
 
         tokens = set()
-        for s in descr:
-            for t in s: tokens.add(t)
+        for d in [ self.train_descr, self.test_descr ]:
+            for s in d:
+                for t in s: tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        data_input = torch.tensor(t, device = self.device)
-
-        self.test_input = data_input[:nb // 5]
-        self.train_input = data_input[nb // 5:]
+        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
+        self.train_input = torch.tensor(t, device = self.device)
+        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
+        self.test_input = torch.tensor(t, device = self.device)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }