Added default configurations and reformated with black. master
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Jan 2023 15:32:53 +0000 (16:32 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Jan 2023 15:32:53 +0000 (16:32 +0100)
main.py
mygpt.py
picoclvr.py

diff --git a/main.py b/main.py
index aa1b517..f7d03cf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -15,111 +15,138 @@ import mygpt
 
 ######################################################################
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
-parser = argparse.ArgumentParser(description = 'My own GPT.')
+parser = argparse.ArgumentParser(description="My own GPT.")
 
-parser.add_argument('--log_filename',
-                    type = str, default = 'train.log')
+parser.add_argument("--log_filename", type=str, default="train.log")
 
-parser.add_argument('--seed',
-                    type = int, default = 0)
+parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = -1)
+parser.add_argument("--nb_epochs", type=int, default=None)
 
-parser.add_argument('--batch_size',
-                    type = int, default = 25)
+parser.add_argument("--batch_size", type=int, default=25)
 
-parser.add_argument('--data',
-                    type = str, default = 'wiki103')
+parser.add_argument("--data", type=str, default="wiki103")
 
-parser.add_argument('--data_size',
-                    type = int, default = -1)
+parser.add_argument("--data_size", type=int, default=None)
 
-parser.add_argument('--optim',
-                    type = str, default = 'adam')
+parser.add_argument("--optim", type=str, default="adam")
 
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-3)
+parser.add_argument("--learning_rate", type=float, default=1e-3)
 
-parser.add_argument('--learning_rate_end',
-                    type = float, default = 1e-6)
+parser.add_argument("--learning_rate_end", type=float, default=1e-6)
 
-parser.add_argument('--dim_model',
-                    type = int, default = 512)
+parser.add_argument("--dim_model", type=int, default=None)
 
-parser.add_argument('--dim_keys',
-                    type = int, default = 64)
+parser.add_argument("--dim_keys", type=int, default=None)
 
-parser.add_argument('--dim_hidden',
-                    type = int, default = 2048)
+parser.add_argument("--dim_hidden", type=int, default=None)
 
-parser.add_argument('--nb_heads',
-                    type = int, default = 8)
+parser.add_argument("--nb_heads", type=int, default=None)
 
-parser.add_argument('--nb_blocks',
-                    type = int, default = 12)
+parser.add_argument("--nb_blocks", type=int, default=None)
 
-parser.add_argument('--dropout',
-                    type = float, default = 0.1)
+parser.add_argument("--dropout", type=float, default=0.1)
 
-parser.add_argument('--deterministic_synthesis',
-                    action='store_true', default = False)
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument('--no_checkpoint',
-                    action='store_true', default = False)
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument('--checkpoint_name',
-                    type = str, default = 'checkpoint.pth')
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_nb_colors',
-                    type = int, default = 5)
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
 
-parser.add_argument('--picoclvr_height',
-                    type = int, default = 12)
+parser.add_argument("--picoclvr_height", type=int, default=12)
 
-parser.add_argument('--picoclvr_width',
-                    type = int, default = 16)
+parser.add_argument("--picoclvr_width", type=int, default=16)
 
 ######################################################################
 
 args = parser.parse_args()
 
-log_file = open(args.log_filename, 'w')
+log_file = open(args.log_filename, "w")
 
 if args.seed >= 0:
     torch.manual_seed(args.seed)
 
 ######################################################################
 
+
 def log_string(s):
-    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
+    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
     if log_file is not None:
-        log_file.write(t + s + '\n')
+        log_file.write(t + s + "\n")
         log_file.flush()
 
     print(t + s)
     sys.stdout.flush()
 
+
 for n in vars(args):
-    log_string(f'args.{n} {getattr(args, n)}')
+    log_string(f"args.{n} {getattr(args, n)}")
 
 ######################################################################
 
+default_args = {
+    "mnist": {
+        "nb_epochs": 10,
+        "dim_model": 64,
+        "dim_keys": 64,
+        "dim_hidden": 128,
+        "nb_heads": 4,
+        "nb_blocks": 6,
+    },
+    "mnist-debug": {
+        "nb_epochs": 2,
+        "data_size": 10000,
+        "dim_model": 8,
+        "dim_keys": 8,
+        "dim_hidden": 8,
+        "nb_heads": 2,
+        "nb_blocks": 4,
+    },
+    "wiki103": {
+        "nb_epochs": 25,
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_blocks": 12,
+    },
+    "picoclvr": {
+        "nb_epochs": 25,
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_blocks": 12,
+    },
+}
+
+if args.data in default_args:
+    for k, v in default_args[args.data].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
+
 def autoregression(
-        model, batch_size,
-        nb_samples, nb_tokens_to_generate, primer = None,
-        device = torch.device('cpu')
+    model,
+    batch_size,
+    nb_samples,
+    nb_tokens_to_generate,
+    primer=None,
+    device=torch.device("cpu"),
 ):
     results = torch.zeros(
-        nb_samples, nb_tokens_to_generate,
-        dtype = torch.int64, device = device
+        nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device
     )
 
     if primer is None:
@@ -135,16 +162,18 @@ def autoregression(
             if args.deterministic_synthesis:
                 t_next = logits.argmax(1)
             else:
-                dist = torch.distributions.categorical.Categorical(logits = logits)
+                dist = torch.distributions.categorical.Categorical(logits=logits)
                 t_next = dist.sample()
             input[:, s] = t_next
 
     return results
 
+
 ######################################################################
 
+
 class Task:
-    def batches(self, split = 'train'):
+    def batches(self, split="train"):
         pass
 
     def vocabulary_size(self):
@@ -153,130 +182,127 @@ class Task:
     def produce_results(self, n_epoch, model):
         pass
 
+
 ######################################################################
 
 import picoclvr
 
+
 class TaskPicoCLVR(Task):
 
     # Make a tensor from a list of strings
     def tensorize(self, descr):
-        token_descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in token_descr ])
-        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
-        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
-        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
-        return torch.tensor(id_descr, device = self.device)
-
-    def trim(self, x, token = '<nul>'):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        padded_token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    def trim(self, x, token="<nul>"):
         n = self.token2id[token]
-        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
         a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
         return x[:, a:b]
 
-    def __init__(self, batch_size,
-                 height, width, nb_colors = 5,
-                 device = torch.device('cpu')):
-
+    def __init__(
+        self, batch_size, height, width, nb_colors=5, device=torch.device("cpu")
+    ):
         def generate_descr(nb):
             return picoclvr.generate(
-                nb,
-                height = self.height, width = self.width,
-                nb_colors = nb_colors
+                nb, height=self.height, width=self.width, nb_colors=nb_colors
             )
 
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
-        nb = args.data_size if args.data_size > 0 else 250000
+        nb = args.data_size if args.data_size is not None else 250000
 
-        log_string(f'generating {nb} samples (can take some time)')
+        log_string(f"generating {nb} samples (can take some time)")
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
         # Build the tokenizer
-        tokens = { '<nul>' }
-        for d in [ self.train_descr, self.test_descr ]:
+        tokens = {"<nul>"}
+        for d in [self.train_descr, self.test_descr]:
             for s in d:
-                for t in s.strip().split(' '): tokens.add(t)
-        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
-        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
 
         # Tokenize the train and test sets
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
-        input = self.train_input if split == 'train' else self.test_input
-        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"):
             yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
+    def test_model(
+        self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
+    ):
         nb_tokens_to_generate = self.height * self.width + 3
-        result_descr = [ ]
+        result_descr = []
 
         for primer_descr in primers_descr:
 
             results = autoregression(
                 model,
                 self.batch_size,
-                nb_samples = nb_per_primer,
-                nb_tokens_to_generate = nb_tokens_to_generate,
-                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
-                device = self.device
+                nb_samples=nb_per_primer,
+                nb_tokens_to_generate=nb_tokens_to_generate,
+                primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
+                device=self.device,
             )
 
-            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
             result_descr += l
 
-        np = picoclvr.nb_properties(
-            result_descr,
-            height = self.height, width = self.width
-        )
+        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
 
         nb_requested_properties, _, nb_missing_properties = zip(*np)
 
-        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+        log_string(
+            f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
+        )
 
-        np=torch.tensor(np)
-        count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+        np = torch.tensor(np)
+        count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
         for i in range(count.size(0)):
             for j in range(count.size(1)):
-                count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+                count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
 
         if generate_images:
             img = [
-                picoclvr.descr2img(d, height = self.height, width = self.width)
+                picoclvr.descr2img(d, height=self.height, width=self.width)
                 for d in result_descr
             ]
 
             img = torch.cat(img, 0)
-            image_name = f'result_picoclvr_{n_epoch:04d}.png'
+            image_name = f"result_picoclvr_{n_epoch:04d}.png"
             torchvision.utils.save_image(
-                img / 255.,
-                image_name, nrow = nb_per_primer, pad_value = 0.8
+                img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
             )
-            log_string(f'wrote {image_name}')
+            log_string(f"wrote {image_name}")
 
         return count
 
     def produce_results(self, n_epoch, model):
         primers_descr = [
-            'red above green <sep> green top <sep> blue right of red <img>',
-            'there is red <sep> there is yellow <sep> there is blue <img>',
-            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
-            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
+            "red above green <sep> green top <sep> blue right of red <img>",
+            "there is red <sep> there is yellow <sep> there is blue <img>",
+            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
+            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
         ]
 
         self.test_model(
-            n_epoch, model,
-            primers_descr,
-            nb_per_primer=8, generate_images=True
+            n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
         )
 
         # FAR TOO SLOW!!!
@@ -284,23 +310,30 @@ class TaskPicoCLVR(Task):
         # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
 
         # count=self.test_model(
-            # n_epoch, model,
-            # test_primers_descr,
-            # nb_per_primer=1, generate_images=False
+        # n_epoch, model,
+        # test_primers_descr,
+        # nb_per_primer=1, generate_images=False
         # )
 
         # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
-            # for i in range(count.size(0)):
-                # for j in range(count.size(1)):
-                    # f.write(f'{count[i,j]}')
-                    # f.write(" " if j<count.size(1)-1 else "\n")
+        # for i in range(count.size(0)):
+        # for j in range(count.size(1)):
+        # f.write(f'{count[i,j]}')
+        # f.write(" " if j<count.size(1)-1 else "\n")
+
 
 ######################################################################
 
-class TaskWiki103(Task):
 
-    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
-                 device = torch.device('cpu')):
+class TaskWiki103(Task):
+    def __init__(
+        self,
+        batch_size,
+        len_min=10,
+        len_max=200,
+        min_freq=100,
+        device=torch.device("cpu"),
+    ):
 
         self.batch_size = batch_size
         self.len_min = len_min
@@ -308,112 +341,117 @@ class TaskWiki103(Task):
         self.min_freq = min_freq
         self.device = device
 
-        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
-        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
+        self.tokenizer = torchtext.data.get_tokenizer("basic_english")
+        train_iter = torchtext.datasets.WikiText103(split="train", root="./data/nlp/")
 
         # Mostly for debug
-        if args.data_size > 0:
+        if args.data_size is not None:
             train_iter = itertools.islice(train_iter, args.data_size)
 
         def yield_tokens():
-            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
+            for l in tqdm.tqdm(train_iter, desc="vocab"):
                 yield self.tokenizer(l)
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
-            yield_tokens(),
-            specials = [ '<unk>', '<nul>' ],
-            min_freq = self.min_freq
+            yield_tokens(), specials=["<unk>", "<nul>"], min_freq=self.min_freq
         )
 
-        self.vocab.set_default_index(self.vocab[ '<unk>' ])
+        self.vocab.set_default_index(self.vocab["<unk>"])
 
     # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
+        return torch.tensor([self.vocab(x + ["<nul>"] * (a - len(x))) for x in s])
 
     def yield_batches(self, ds):
-        s = [ ]
+        s = []
         for l in ds:
             q = self.tokenizer(l)
             if len(q) >= self.len_min and len(q) <= self.len_max:
-                s += [ q ]
+                s += [q]
                 if len(s) == self.batch_size:
                     yield self.tensorize(s)
-                    s = [ ]
+                    s = []
 
         if len(s) > 0:
             yield self.tensorize(s)
 
-    def batches(self, split = 'train'):
-        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
+    def batches(self, split="train"):
+        data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/")
 
         # Mostly for debug
-        if args.data_size > 0:
+        if args.data_size is not None:
             data_iter = itertools.islice(data_iter, args.data_size)
 
-        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
+        return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}"))
 
     def vocabulary_size(self):
         return len(self.vocab)
 
     def produce_results(self, n_epoch, model):
         nb_tokens = 50
-        file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
-        with open(file_name, 'w') as outfile:
-             for primer in [
-                     'the cat is hunting a',
-                     'paris is the capital',
-                     'cars are convenient',
-                     'the difference between men and women is',
-                     'the object was blue all over and green all over it was',
-                     'cherries are red and lemons are',
-                     'cherries are sweet and lemons are',
-                     'two plus three equals',
-                     'deep learning is',
-             ]:
-                 t_primer = self.tokenizer(primer)
-                 t_generated = [ ]
-
-                 for j in range(nb_tokens):
-
-                     input = self.tensorize([ t_primer + t_generated ]).to(self.device)
-                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
-                     output = model(input)
-                     logits = output[0, -1]
-                     if args.deterministic_synthesis:
-                         t_next = logits.argmax()
-                     else:
-                         dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t_next = dist.sample()
-                     t_generated.append(self.vocab.lookup_token(t_next))
-                     if t_generated[-1] == '<nul>': break
-
-                 s = ' '.join(t_generated)
-
-                 outfile.write(f'<{primer}> {s}\n')
-
-        log_string(f'wrote {file_name}')
+        file_name = f"result_wiki103_{n_epoch:04d}.txt"
+
+        with open(file_name, "w") as outfile:
+            for primer in [
+                "the cat is hunting a",
+                "paris is the capital",
+                "cars are convenient",
+                "the difference between men and women is",
+                "the object was blue all over and green all over it was",
+                "cherries are red and lemons are",
+                "cherries are sweet and lemons are",
+                "two plus three equals",
+                "deep learning is",
+            ]:
+                t_primer = self.tokenizer(primer)
+                t_generated = []
+
+                for j in range(nb_tokens):
+
+                    input = self.tensorize([t_primer + t_generated]).to(self.device)
+                    input = F.pad(
+                        input, (0, 1)
+                    )  # Add the next token, the one to predict
+                    output = model(input)
+                    logits = output[0, -1]
+                    if args.deterministic_synthesis:
+                        t_next = logits.argmax()
+                    else:
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        t_next = dist.sample()
+                    t_generated.append(self.vocab.lookup_token(t_next))
+                    if t_generated[-1] == "<nul>":
+                        break
+
+                s = " ".join(t_generated)
+
+                outfile.write(f"<{primer}> {s}\n")
+
+        log_string(f"wrote {file_name}")
+
 
 ######################################################################
 
-class TaskMNIST(Task):
 
-    def __init__(self, batch_size, device = torch.device('cpu')):
+class TaskMNIST(Task):
+    def __init__(self, batch_size, device=torch.device("cpu")):
         self.device = device
         self.batch_size = batch_size
 
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
         data_set = torchvision.datasets.MNIST(
-            root = './data', train = (split == 'train'),
-            download = True
+            root="./data", train=(split == "train"), download=True
         )
         data_input = data_set.data.view(-1, 28 * 28).long()
-        if args.data_size >= 0:
-            data_input = data_input[:args.data_size]
-        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
+        if args.data_size is not None:
+            data_input = data_input[: args.data_size]
+        for batch in tqdm.tqdm(
+            data_input.split(self.batch_size), desc=f"epoch-{split}"
+        ):
             yield batch
 
     def vocabulary_size(self):
@@ -421,108 +459,118 @@ class TaskMNIST(Task):
 
     def produce_results(self, n_epoch, model):
         nb_samples = 64
-        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
-        image_name = f'result_mnist_{n_epoch:04d}.png'
-        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
-                                     image_name, nrow = 16, pad_value = 0.8)
-        log_string(f'wrote {image_name}')
+        results = autoregression(
+            model, self.batch_size, nb_samples, 28 * 28, device=self.device
+        )
+        image_name = f"result_mnist_{n_epoch:04d}.png"
+        torchvision.utils.save_image(
+            1 - results.reshape(-1, 1, 28, 28) / 255.0,
+            image_name,
+            nrow=16,
+            pad_value=0.8,
+        )
+        log_string(f"wrote {image_name}")
+
 
 ######################################################################
 
-log_string(f'device {device}')
-
-if args.data == 'wiki103':
-    nb_epochs_default = 10
-    task = TaskWiki103(batch_size = args.batch_size, device = device)
-elif args.data == 'mnist':
-    nb_epochs_default = 25
-    task = TaskMNIST(batch_size = args.batch_size, device = device)
-elif args.data == 'picoclvr':
-    nb_epochs_default = 10
-    task = TaskPicoCLVR(batch_size = args.batch_size,
-                        height = args.picoclvr_height,
-                        width = args.picoclvr_width,
-                        nb_colors = args.picoclvr_nb_colors,
-                        device = device)
+log_string(f"device {device}")
+
+if args.data == "wiki103":
+    task = TaskWiki103(batch_size=args.batch_size, device=device)
+elif args.data in {"mnist", "mnist-debug"}:
+    task = TaskMNIST(batch_size=args.batch_size, device=device)
+elif args.data == "picoclvr":
+    task = TaskPicoCLVR(
+        batch_size=args.batch_size,
+        height=args.picoclvr_height,
+        width=args.picoclvr_width,
+        nb_colors=args.picoclvr_nb_colors,
+        device=device,
+    )
 else:
-    raise ValueError(f'Unknown dataset {args.data}.')
+    raise ValueError(f"Unknown dataset {args.data}.")
 
 vocabulary_size = task.vocabulary_size()
 
-log_string(f'vocabulary_size {vocabulary_size}')
+log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
 
 model = mygpt.MyGPT(
-    vocabulary_size = vocabulary_size,
-    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
-    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
+    vocabulary_size=vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    dropout=args.dropout,
 )
 
 model.to(device)
 
 nb_parameters = sum(p.numel() for p in model.parameters())
-log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'not trying to load checkpoint.')
+    log_string(f"not trying to load checkpoint.")
 
 else:
     try:
         checkpoint = torch.load(args.checkpoint_name)
-        nb_epochs_finished = checkpoint['nb_epochs_finished']
-        model.load_state_dict(checkpoint['model_state'])
-        torch.set_rng_state(checkpoint['rng_state'])
+        nb_epochs_finished = checkpoint["nb_epochs_finished"]
+        model.load_state_dict(checkpoint["model_state"])
+        torch.set_rng_state(checkpoint["rng_state"])
         if torch.cuda.is_available():
-            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
-        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
+            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
 
     except FileNotFoundError:
-        log_string('starting from scratch.')
+        log_string("starting from scratch.")
 
     except:
-        log_string('error when loading the checkpoint.')
+        log_string("error when loading the checkpoint.")
         exit(1)
 
 ######################################################################
 
-nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
-
 token_count = 0
-for input in task.batches(split = 'train'):
-    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+for input in task.batches(split="train"):
+    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
 token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
 
-for n_epoch in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     if args.learning_rate_end < 0:
         lr = args.learning_rate
     else:
-        u = n_epoch / (nb_epochs - 1)
-        lr = math.exp((1 - u) * math.log(args.learning_rate) +
-                      u * math.log(args.learning_rate_end))
-        log_string(f'learning_rate {lr}')
-
-    if args.optim == 'sgd':
-        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
-    elif args.optim == 'adam':
-        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
-    elif args.optim == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
+        u = n_epoch / (args.nb_epochs - 1)
+        lr = math.exp(
+            (1 - u) * math.log(args.learning_rate)
+            + u * math.log(args.learning_rate_end)
+        )
+        log_string(f"learning_rate {lr}")
+
+    if args.optim == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    elif args.optim == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    elif args.optim == "adamw":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     else:
-        raise ValueError(f'Unknown optimizer {args.optim}.')
+        raise ValueError(f"Unknown optimizer {args.optim}.")
 
     model.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in task.batches(split = 'train'):
+    for input in task.batches(split="train"):
         input = input.to(device)
         output = model(input)
         loss = F.cross_entropy(output.transpose(1, 2), input)
@@ -539,28 +587,30 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
         nb_test_samples, acc_test_loss = 0, 0.0
 
-        for input in task.batches(split = 'test'):
+        for input in task.batches(split="test"):
             input = input.to(device)
             output = model(input)
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
-        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
 
-        log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+        log_string(
+            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+        )
 
         task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': n_epoch + 1,
-        'model_state': model.state_dict(),
-        'rng_state': torch.get_rng_state(),
+        "nb_epochs_finished": n_epoch + 1,
+        "model_state": model.state_dict(),
+        "rng_state": torch.get_rng_state(),
     }
 
     if torch.cuda.is_available():
-        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
 
     torch.save(checkpoint, args.checkpoint_name)
 
index f954797..a6b257c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -14,6 +14,7 @@ from torch.nn import functional as F
 
 ##############################
 
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -22,8 +23,10 @@ class WithResidual(nn.Module):
     def forward(self, x):
         return x + self.f(x)
 
+
 ##############################
 
+
 class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
@@ -31,18 +34,20 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
     def forward(self, x):
-        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
-        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
-        k = j%2
-        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
         return x + pe
 
+
 ##############################
 
+
 class QKVAttention(nn.Module):
-    def __init__(self,
-                 dim_in, dim_qk, dim_v,
-                 nb_heads = 1, causal = False, attention_dropout = 0.0):
+    def __init__(
+        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+    ):
         super().__init__()
 
         def randw(*d):
@@ -56,36 +61,47 @@ class QKVAttention(nn.Module):
         self.w_v = randw(nb_heads, dim_v, dim_in)
         self.w_o = randw(dim_v * nb_heads, dim_in)
 
-    def forward(self, x_q, x_kv = None):
-        if x_kv is None: x_kv = x_q
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
 
-        q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
-        k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
-        v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
+        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
 
-        a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
+        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
 
         if self.causal:
-            forbidden_attention = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
-                                < torch.arange(a.size(3), device = q.device)[None, None, None, :]
-            a = a.masked_fill(forbidden_attention, float('-inf'))
+            forbidden_attention = (
+                torch.arange(a.size(2), device=q.device)[None, None, :, None]
+                < torch.arange(a.size(3), device=q.device)[None, None, None, :]
+            )
+            a = a.masked_fill(forbidden_attention, float("-inf"))
 
-        a = a.softmax(dim = 3)
+        a = a.softmax(dim=3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
+        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)
 
         y = y @ self.w_o
 
         return y
 
+
 ##############################
 
+
 class MyGPT(nn.Module):
-    def __init__(self,
-                 vocabulary_size,
-                 dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks,
-                 dropout = 0.0, len_max = 1e5):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
 
         super().__init__()
 
@@ -97,37 +113,38 @@ class MyGPT(nn.Module):
             AddPositionalEncoding(len_max),
         )
 
-        trunk_blocks = [ ]
+        trunk_blocks = []
 
         for _ in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
                     nn.LayerNorm((dim_model,)),
                     QKVAttention(
-                        dim_in = dim_model,
-                        dim_qk = dim_keys,
-                        dim_v = dim_model // nb_heads,
-                        nb_heads = nb_heads,
-                        causal = True, attention_dropout = dropout
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        causal=True,
+                        attention_dropout=dropout,
                     ),
                 ),
                 WithResidual(
                     nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                     nn.ReLU(),
-                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                     nn.Dropout(dropout),
                 ),
             ]
 
         self.trunk = nn.Sequential(*trunk_blocks)
 
-        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
 
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean = 0, std = 2e-2)
+                    m.weight.normal_(mean=0, std=2e-2)
                 elif isinstance(m, nn.LayerNorm):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
@@ -139,19 +156,23 @@ class MyGPT(nn.Module):
         x = self.readout(x)
         return x
 
+
 ######################################################################
 
-if __name__ == '__main__':
-    print('Basic check.')
+if __name__ == "__main__":
+    print("Basic check.")
 
     vocabulary_size = 10
     x = torch.randint(vocabulary_size, (25, 100))
 
     model = MyGPT(
-        vocabulary_size = vocabulary_size,
-        dim_model = 18, dim_keys = 50, dim_hidden = 100,
-        nb_heads = 2, nb_blocks = 3,
-        dropout = 0.1
+        vocabulary_size=vocabulary_size,
+        dim_model=18,
+        dim_keys=50,
+        dim_hidden=100,
+        nb_heads=2,
+        nb_blocks=3,
+        dropout=0.1,
     )
 
     y = model(x)
index 059e352..fb791fe 100755 (executable)
 
 import torch, torchvision
 
-colors = [
-    [ 255, 255, 255 ], [ 255, 0, 0 ], [ 0, 128, 0 ], [ 0, 0, 255 ], [ 255, 255, 0 ],
-    [ 0, 0, 0 ], [ 128, 0, 0 ], [ 139, 0, 0 ], [ 165, 42, 42 ], [ 178, 34, 34 ],
-    [ 220, 20, 60 ], [ 255, 99, 71 ], [ 255, 127, 80 ], [ 205, 92, 92 ], [ 240, 128, 128 ],
-    [ 233, 150, 122 ], [ 250, 128, 114 ], [ 255, 160, 122 ], [ 255, 69, 0 ], [ 255, 140, 0 ],
-    [ 255, 165, 0 ], [ 255, 215, 0 ], [ 184, 134, 11 ], [ 218, 165, 32 ], [ 238, 232, 170 ],
-    [ 189, 183, 107 ], [ 240, 230, 140 ], [ 128, 128, 0 ], [ 154, 205, 50 ], [ 85, 107, 47 ],
-    [ 107, 142, 35 ], [ 124, 252, 0 ], [ 127, 255, 0 ], [ 173, 255, 47 ], [ 0, 100, 0 ],
-    [ 34, 139, 34 ], [ 0, 255, 0 ], [ 50, 205, 50 ], [ 144, 238, 144 ], [ 152, 251, 152 ],
-    [ 143, 188, 143 ], [ 0, 250, 154 ], [ 0, 255, 127 ], [ 46, 139, 87 ], [ 102, 205, 170 ],
-    [ 60, 179, 113 ], [ 32, 178, 170 ], [ 47, 79, 79 ], [ 0, 128, 128 ], [ 0, 139, 139 ],
-    [ 0, 255, 255 ], [ 0, 255, 255 ], [ 224, 255, 255 ], [ 0, 206, 209 ], [ 64, 224, 208 ],
-    [ 72, 209, 204 ], [ 175, 238, 238 ], [ 127, 255, 212 ], [ 176, 224, 230 ], [ 95, 158, 160 ],
-    [ 70, 130, 180 ], [ 100, 149, 237 ], [ 0, 191, 255 ], [ 30, 144, 255 ], [ 173, 216, 230 ],
-    [ 135, 206, 235 ], [ 135, 206, 250 ], [ 25, 25, 112 ], [ 0, 0, 128 ], [ 0, 0, 139 ],
-    [ 0, 0, 205 ], [ 65, 105, 225 ], [ 138, 43, 226 ], [ 75, 0, 130 ], [ 72, 61, 139 ],
-    [ 106, 90, 205 ], [ 123, 104, 238 ], [ 147, 112, 219 ], [ 139, 0, 139 ], [ 148, 0, 211 ],
-    [ 153, 50, 204 ], [ 186, 85, 211 ], [ 128, 0, 128 ], [ 216, 191, 216 ], [ 221, 160, 221 ],
-    [ 238, 130, 238 ], [ 255, 0, 255 ], [ 218, 112, 214 ], [ 199, 21, 133 ], [ 219, 112, 147 ],
-    [ 255, 20, 147 ], [ 255, 105, 180 ], [ 255, 182, 193 ], [ 255, 192, 203 ], [ 250, 235, 215 ],
-    [ 245, 245, 220 ], [ 255, 228, 196 ], [ 255, 235, 205 ], [ 245, 222, 179 ], [ 255, 248, 220 ],
-    [ 255, 250, 205 ], [ 250, 250, 210 ], [ 255, 255, 224 ], [ 139, 69, 19 ], [ 160, 82, 45 ],
-    [ 210, 105, 30 ], [ 205, 133, 63 ], [ 244, 164, 96 ], [ 222, 184, 135 ], [ 210, 180, 140 ],
-    [ 188, 143, 143 ], [ 255, 228, 181 ], [ 255, 222, 173 ], [ 255, 218, 185 ], [ 255, 228, 225 ],
-    [ 255, 240, 245 ], [ 250, 240, 230 ], [ 253, 245, 230 ], [ 255, 239, 213 ], [ 255, 245, 238 ],
-    [ 245, 255, 250 ], [ 112, 128, 144 ], [ 119, 136, 153 ], [ 176, 196, 222 ], [ 230, 230, 250 ],
-    [ 255, 250, 240 ], [ 240, 248, 255 ], [ 248, 248, 255 ], [ 240, 255, 240 ], [ 255, 255, 240 ],
-    [ 240, 255, 255 ], [ 255, 250, 250 ], [ 192, 192, 192 ], [ 220, 220, 220 ], [ 245, 245, 245 ],
-]
-
-color_names = [
-    'white', 'red', 'green', 'blue', 'yellow',
-    'black', 'maroon', 'dark_red', 'brown', 'firebrick',
-    'crimson', 'tomato', 'coral', 'indian_red', 'light_coral',
-    'dark_salmon', 'salmon', 'light_salmon', 'orange_red', 'dark_orange',
-    'orange', 'gold', 'dark_golden_rod', 'golden_rod', 'pale_golden_rod',
-    'dark_khaki', 'khaki', 'olive', 'yellow_green', 'dark_olive_green',
-    'olive_drab', 'lawn_green', 'chartreuse', 'green_yellow', 'dark_green',
-    'forest_green', 'lime', 'lime_green', 'light_green', 'pale_green',
-    'dark_sea_green', 'medium_spring_green', 'spring_green', 'sea_green', 'medium_aqua_marine',
-    'medium_sea_green', 'light_sea_green', 'dark_slate_gray', 'teal', 'dark_cyan',
-    'aqua', 'cyan', 'light_cyan', 'dark_turquoise', 'turquoise',
-    'medium_turquoise', 'pale_turquoise', 'aqua_marine', 'powder_blue', 'cadet_blue',
-    'steel_blue', 'corn_flower_blue', 'deep_sky_blue', 'dodger_blue', 'light_blue',
-    'sky_blue', 'light_sky_blue', 'midnight_blue', 'navy', 'dark_blue',
-    'medium_blue', 'royal_blue', 'blue_violet', 'indigo', 'dark_slate_blue',
-    'slate_blue', 'medium_slate_blue', 'medium_purple', 'dark_magenta', 'dark_violet',
-    'dark_orchid', 'medium_orchid', 'purple', 'thistle', 'plum',
-    'violet', 'magenta', 'orchid', 'medium_violet_red', 'pale_violet_red',
-    'deep_pink', 'hot_pink', 'light_pink', 'pink', 'antique_white',
-    'beige', 'bisque', 'blanched_almond', 'wheat', 'corn_silk',
-    'lemon_chiffon', 'light_golden_rod_yellow', 'light_yellow', 'saddle_brown', 'sienna',
-    'chocolate', 'peru', 'sandy_brown', 'burly_wood', 'tan',
-    'rosy_brown', 'moccasin', 'navajo_white', 'peach_puff', 'misty_rose',
-    'lavender_blush', 'linen', 'old_lace', 'papaya_whip', 'sea_shell',
-    'mint_cream', 'slate_gray', 'light_slate_gray', 'light_steel_blue', 'lavender',
-    'floral_white', 'alice_blue', 'ghost_white', 'honeydew', 'ivory',
-    'azure', 'snow', 'silver', 'gainsboro', 'white_smoke',
-]
-
-color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] )
-color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
+color_tokens = {
+    "white": [255, 255, 255],
+    "red": [255, 0, 0],
+    "green": [0, 128, 0],
+    "blue": [0, 0, 255],
+    "yellow": [255, 255, 0],
+    "black": [0, 0, 0],
+    "maroon": [128, 0, 0],
+    "dark_red": [139, 0, 0],
+    "brown": [165, 42, 42],
+    "firebrick": [178, 34, 34],
+    "crimson": [220, 20, 60],
+    "tomato": [255, 99, 71],
+    "coral": [255, 127, 80],
+    "indian_red": [205, 92, 92],
+    "light_coral": [240, 128, 128],
+    "dark_salmon": [233, 150, 122],
+    "salmon": [250, 128, 114],
+    "light_salmon": [255, 160, 122],
+    "orange_red": [255, 69, 0],
+    "dark_orange": [255, 140, 0],
+    "orange": [255, 165, 0],
+    "gold": [255, 215, 0],
+    "dark_golden_rod": [184, 134, 11],
+    "golden_rod": [218, 165, 32],
+    "pale_golden_rod": [238, 232, 170],
+    "dark_khaki": [189, 183, 107],
+    "khaki": [240, 230, 140],
+    "olive": [128, 128, 0],
+    "yellow_green": [154, 205, 50],
+    "dark_olive_green": [85, 107, 47],
+    "olive_drab": [107, 142, 35],
+    "lawn_green": [124, 252, 0],
+    "chartreuse": [127, 255, 0],
+    "green_yellow": [173, 255, 47],
+    "dark_green": [0, 100, 0],
+    "forest_green": [34, 139, 34],
+    "lime": [0, 255, 0],
+    "lime_green": [50, 205, 50],
+    "light_green": [144, 238, 144],
+    "pale_green": [152, 251, 152],
+    "dark_sea_green": [143, 188, 143],
+    "medium_spring_green": [0, 250, 154],
+    "spring_green": [0, 255, 127],
+    "sea_green": [46, 139, 87],
+    "medium_aqua_marine": [102, 205, 170],
+    "medium_sea_green": [60, 179, 113],
+    "light_sea_green": [32, 178, 170],
+    "dark_slate_gray": [47, 79, 79],
+    "teal": [0, 128, 128],
+    "dark_cyan": [0, 139, 139],
+    "aqua": [0, 255, 255],
+    "cyan": [0, 255, 255],
+    "light_cyan": [224, 255, 255],
+    "dark_turquoise": [0, 206, 209],
+    "turquoise": [64, 224, 208],
+    "medium_turquoise": [72, 209, 204],
+    "pale_turquoise": [175, 238, 238],
+    "aqua_marine": [127, 255, 212],
+    "powder_blue": [176, 224, 230],
+    "cadet_blue": [95, 158, 160],
+    "steel_blue": [70, 130, 180],
+    "corn_flower_blue": [100, 149, 237],
+    "deep_sky_blue": [0, 191, 255],
+    "dodger_blue": [30, 144, 255],
+    "light_blue": [173, 216, 230],
+    "sky_blue": [135, 206, 235],
+    "light_sky_blue": [135, 206, 250],
+    "midnight_blue": [25, 25, 112],
+    "navy": [0, 0, 128],
+    "dark_blue": [0, 0, 139],
+    "medium_blue": [0, 0, 205],
+    "royal_blue": [65, 105, 225],
+    "blue_violet": [138, 43, 226],
+    "indigo": [75, 0, 130],
+    "dark_slate_blue": [72, 61, 139],
+    "slate_blue": [106, 90, 205],
+    "medium_slate_blue": [123, 104, 238],
+    "medium_purple": [147, 112, 219],
+    "dark_magenta": [139, 0, 139],
+    "dark_violet": [148, 0, 211],
+    "dark_orchid": [153, 50, 204],
+    "medium_orchid": [186, 85, 211],
+    "purple": [128, 0, 128],
+    "thistle": [216, 191, 216],
+    "plum": [221, 160, 221],
+    "violet": [238, 130, 238],
+    "magenta": [255, 0, 255],
+    "orchid": [218, 112, 214],
+    "medium_violet_red": [199, 21, 133],
+    "pale_violet_red": [219, 112, 147],
+    "deep_pink": [255, 20, 147],
+    "hot_pink": [255, 105, 180],
+    "light_pink": [255, 182, 193],
+    "pink": [255, 192, 203],
+    "antique_white": [250, 235, 215],
+    "beige": [245, 245, 220],
+    "bisque": [255, 228, 196],
+    "blanched_almond": [255, 235, 205],
+    "wheat": [245, 222, 179],
+    "corn_silk": [255, 248, 220],
+    "lemon_chiffon": [255, 250, 205],
+    "light_golden_rod_yellow": [250, 250, 210],
+    "light_yellow": [255, 255, 224],
+    "saddle_brown": [139, 69, 19],
+    "sienna": [160, 82, 45],
+    "chocolate": [210, 105, 30],
+    "peru": [205, 133, 63],
+    "sandy_brown": [244, 164, 96],
+    "burly_wood": [222, 184, 135],
+    "tan": [210, 180, 140],
+    "rosy_brown": [188, 143, 143],
+    "moccasin": [255, 228, 181],
+    "navajo_white": [255, 222, 173],
+    "peach_puff": [255, 218, 185],
+    "misty_rose": [255, 228, 225],
+    "lavender_blush": [255, 240, 245],
+    "linen": [250, 240, 230],
+    "old_lace": [253, 245, 230],
+    "papaya_whip": [255, 239, 213],
+    "sea_shell": [255, 245, 238],
+    "mint_cream": [245, 255, 250],
+    "slate_gray": [112, 128, 144],
+    "light_slate_gray": [119, 136, 153],
+    "light_steel_blue": [176, 196, 222],
+    "lavender": [230, 230, 250],
+    "floral_white": [255, 250, 240],
+    "alice_blue": [240, 248, 255],
+    "ghost_white": [248, 248, 255],
+    "honeydew": [240, 255, 240],
+    "ivory": [255, 255, 240],
+    "azure": [240, 255, 255],
+    "snow": [255, 250, 250],
+    "silver": [192, 192, 192],
+    "gainsboro": [220, 220, 220],
+    "white_smoke": [245, 245, 245],
+}
+
+color_id = dict([(n, k) for k, n in enumerate(color_tokens.keys())])
+color_names = dict([(k, n) for k, n in enumerate(color_tokens.keys())])
 
 ######################################################################
 
-def all_properties(height, width, nb_squares, square_i, square_j, square_c):
-    s = [ ]
-
-    for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
-        s += [ f'there is {c}' ]
 
-        if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
-        if square_i[r] < height//3: s += [ f'{c} top' ]
-        if square_j[r] >= width - width//3: s += [ f'{c} right' ]
-        if square_j[r] < width//3: s += [ f'{c} left' ]
-
-        for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
-            if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
-            if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
-            if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
-            if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+    s = []
+
+    for r, c in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
+        s += [f"there is {c}"]
+
+        if square_i[r] >= height - height // 3:
+            s += [f"{c} bottom"]
+        if square_i[r] < height // 3:
+            s += [f"{c} top"]
+        if square_j[r] >= width - width // 3:
+            s += [f"{c} right"]
+        if square_j[r] < width // 3:
+            s += [f"{c} left"]
+
+        for t, d in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
+            if square_i[r] > square_i[t]:
+                s += [f"{c} below {d}"]
+            if square_i[r] < square_i[t]:
+                s += [f"{c} above {d}"]
+            if square_j[r] > square_j[t]:
+                s += [f"{c} right of {d}"]
+            if square_j[r] < square_j[t]:
+                s += [f"{c} left of {d}"]
 
     return s
 
+
 ######################################################################
 
-def generate(nb, height, width,
-             max_nb_squares = 5, max_nb_properties = 10,
-             nb_colors = 5,
-             pruning_criterion = None):
+
+def generate(
+    nb,
+    height,
+    width,
+    max_nb_squares=5,
+    max_nb_properties=10,
+    nb_colors=5,
+    pruning_criterion=None,
+):
 
     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
 
-    descr = [ ]
+    descr = []
 
     for n in range(nb):
 
@@ -108,70 +202,77 @@ def generate(nb, height, width,
         square_position = torch.randperm(height * width)[:nb_squares]
         # color 0 is white and reserved for the background
         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
-        square_i = square_position.div(width, rounding_mode = 'floor')
+        square_i = square_position.div(width, rounding_mode="floor")
         square_j = square_position % width
 
-        img = [ 0 ] * height * width
-        for k in range(nb_squares): img[square_position[k]] = square_c[k]
+        img = torch.zeros(height * width, dtype=torch.int64)
+        for k in range(nb_squares):
+            img[square_position[k]] = square_c[k]
 
         # generates all the true properties
 
         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
 
         if pruning_criterion is not None:
-            s = list(filter(pruning_criterion,s))
+            s = list(filter(pruning_criterion, s))
 
         # pick at most max_nb_properties at random
 
         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
-        s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] )
-        s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
+        s = " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+        s += " <img> " + " ".join([f"{color_names[n.item()]}" for n in img])
 
-        descr += [ s ]
+        descr += [s]
 
     return descr
 
+
 ######################################################################
 
+
 def descr2img(descr, height, width):
 
     if type(descr) == list:
-        return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
+        return torch.cat([descr2img(d, height, width) for d in descr], 0)
 
     def token2color(t):
         try:
             return color_tokens[t]
         except KeyError:
-            return [ 128, 128, 128 ]
+            return [128, 128, 128]
 
-    d = descr.split('<img>', 1)
-    d = d[-1] if len(d) > 1 else ''
-    d = d.strip().split(' ')[:height * width]
-    d = d + [ '<unk>' ] * (height * width - len(d))
-    d = [ token2color(t) for t in d ]
+    d = descr.split("<img>", 1)
+    d = d[-1] if len(d) > 1 else ""
+    d = d.strip().split(" ")[: height * width]
+    d = d + ["<unk>"] * (height * width - len(d))
+    d = [token2color(t) for t in d]
     img = torch.tensor(d).permute(1, 0)
     img = img.reshape(1, 3, height, width)
 
     return img
 
+
 ######################################################################
 
+
 def descr2properties(descr, height, width):
 
     if type(descr) == list:
-        return [ descr2properties(d, height, width) for d in descr ]
+        return [descr2properties(d, height, width) for d in descr]
 
-    d = descr.split('<img>', 1)
-    d = d[-1] if len(d) > 1 else ''
-    d = d.strip().split(' ')[:height * width]
+    d = descr.split("<img>", 1)
+    d = d[-1] if len(d) > 1 else ""
+    d = d.strip().split(" ")[: height * width]
 
     seen = {}
-    if len(d) != height * width: return []
+    if len(d) != height * width:
+        return []
 
     for k, x in enumerate(d):
         if x != color_names[0]:
             if x in color_tokens:
-                if x in seen: return []
+                if x in seen:
+                    return []
             else:
                 return []
             seen[x] = (color_id[x], k // width, k % width)
@@ -190,16 +291,19 @@ def descr2properties(descr, height, width):
 
     return s
 
+
 ######################################################################
 
+
 def nb_properties(descr, height, width):
     if type(descr) == list:
-        return [ nb_properties(d, height, width) for d in descr ]
+        return [nb_properties(d, height, width) for d in descr]
 
-    d = descr.split('<img>', 1)
-    if len(d) == 0: return 0
-    d = d[0].strip().split('<sep>')
-    d = [ x.strip() for x in d ]
+    d = descr.split("<img>", 1)
+    if len(d) == 0:
+        return 0
+    d = d[0].strip().split("<sep>")
+    d = [x.strip() for x in d]
 
     requested_properties = set(d)
     all_properties = set(descr2properties(descr, height, width))
@@ -207,30 +311,36 @@ def nb_properties(descr, height, width):
 
     return (len(requested_properties), len(all_properties), len(missing_properties))
 
+
 ######################################################################
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     descr = generate(
-        nb = 5, height = 12, width = 16,
-        pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+        nb=5,
+        height=12,
+        width=16,
+        pruning_criterion=lambda s: not (
+            "green" in s and ("right" in s or "left" in s)
+        ),
     )
 
-    print(descr2properties(descr, height = 12, width = 16))
-    print(nb_properties(descr, height = 12, width = 16))
+    print(descr2properties(descr, height=12, width=16))
+    print(nb_properties(descr, height=12, width=16))
 
-    with open('picoclvr_example.txt', 'w') as f:
+    with open("picoclvr_example.txt", "w") as f:
         for d in descr:
-            f.write(f'{d}\n\n')
+            f.write(f"{d}\n\n")
 
-    img = descr2img(descr, height = 12, width = 16)
-    torchvision.utils.save_image(img / 255.,
-                                 'picoclvr_example.png', nrow = 16, pad_value = 0.8)
+    img = descr2img(descr, height=12, width=16)
+    torchvision.utils.save_image(
+        img / 255.0, "picoclvr_example.png", nrow=16, pad_value=0.8
+    )
 
     import time
 
     start_time = time.perf_counter()
-    descr = generate(nb = 1000, height = 12, width = 16)
+    descr = generate(nb=1000, height=12, width=16)
     end_time = time.perf_counter()
-    print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
+    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
 
 ######################################################################