Update.
[mygpt.git] / picoclvr.py
index 3ecbf3a..059e352 100755 (executable)
@@ -95,7 +95,8 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
 
 def generate(nb, height, width,
              max_nb_squares = 5, max_nb_properties = 10,
-             nb_colors = 5):
+             nb_colors = 5,
+             pruning_criterion = None):
 
     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
 
@@ -117,6 +118,9 @@ def generate(nb, height, width,
 
         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
 
+        if pruning_criterion is not None:
+            s = list(filter(pruning_criterion,s))
+
         # pick at most max_nb_properties at random
 
         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
@@ -206,23 +210,26 @@ def nb_properties(descr, height, width):
 ######################################################################
 
 if __name__ == '__main__':
-    descr = generate(nb = 5)
+    descr = generate(
+        nb = 5, height = 12, width = 16,
+        pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+    )
 
-    #print(descr2properties(descr))
-    print(nb_properties(descr))
+    print(descr2properties(descr, height = 12, width = 16))
+    print(nb_properties(descr, height = 12, width = 16))
 
     with open('picoclvr_example.txt', 'w') as f:
         for d in descr:
             f.write(f'{d}\n\n')
 
-    img = descr2img(descr)
+    img = descr2img(descr, height = 12, width = 16)
     torchvision.utils.save_image(img / 255.,
                                  'picoclvr_example.png', nrow = 16, pad_value = 0.8)
 
     import time
 
     start_time = time.perf_counter()
-    descr = generate(nb = 1000)
+    descr = generate(nb = 1000, height = 12, width = 16)
     end_time = time.perf_counter()
     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')