Update.
[picoclvr.git] / picoclvr.py
index 5da3943..0cd3062 100755 (executable)
@@ -5,6 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import math
 import torch, torchvision
 import torch.nn.functional as F
 
@@ -201,7 +202,12 @@ def generate(
     descr = []
 
     for n in range(nb):
-        nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+        # we want uniform over the combinations of 1 to max_nb_squares
+        # pixels of nb_colors
+        logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        nb_squares = dist.sample((1,)) + 1
+        # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
         square_position = torch.randperm(height * width)[:nb_squares]
 
         # color 0 is white and reserved for the background