Finalized PicoCLVR with "many colors".
authorFrancois Fleuret <francois@fleuret.org>
Mon, 20 Jun 2022 06:14:46 +0000 (08:14 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 20 Jun 2022 06:14:46 +0000 (08:14 +0200)
main.py
picoclvr.py

diff --git a/main.py b/main.py
index a31284e..3bf7587 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -111,12 +111,20 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
-    def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
+    def __init__(self, batch_size,
+                 height = 6, width = 8, many_colors = False,
+                 device = torch.device('cpu')):
+
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
-        descr = picoclvr.generate(nb, height = height, width = width)
+        descr = picoclvr.generate(
+            nb,
+            height = height, width = width,
+            many_colors = many_colors
+        )
+
         descr = [ s.strip().split(' ') for s in descr ]
         l = max([ len(s) for s in descr ])
         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
index f4d7a65..712da17 100755 (executable)
@@ -71,7 +71,9 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
 
 ######################################################################
 
-def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False):
+def generate(nb, height = 6, width = 8,
+             max_nb_squares = 5, max_nb_statements = 10,
+             many_colors = False):
 
     nb_colors =  len(color_tokens) - 1 if many_colors else max_nb_squares