Update.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index c810eef..f65bb8e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -18,15 +18,11 @@ import mygpt
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-
 parser = argparse.ArgumentParser(description = 'My own GPT.')
 
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
@@ -78,8 +74,8 @@ parser.add_argument('--checkpoint_name',
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
 
 parser.add_argument('--picoclvr_height',
                     type = int, default = 12)
@@ -113,22 +109,34 @@ for n in vars(args):
 
 ######################################################################
 
-def produce_results(
-        self,
-        model, nb_samples, nb_tokens_to_generate, starting_input = None,
-        device = 'cpu'
+def autoregression(
+        model, batch_size,
+        nb_samples, nb_tokens_to_generate, primer = None,
+        device = torch.device('cpu')
 ):
-    results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
-    for input in results.split(self.batch_size):
-        for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if primer is None:
+        first = 0
+    else:
+        first = primer.size(1)
+        results = torch.cat((primer, results), 1)
+
+    for input in results.split(batch_size):
+        for s in range(first, input.size(1)):
             output = model(input)
             logits = output[:, s]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax(1)
-            input[:, s + 1] = t
+                t_next = logits.argmax(1)
+            input[:, s] = t_next
+
+    return results
 
 ######################################################################
 
@@ -139,7 +147,7 @@ class Task:
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
 ######################################################################
@@ -148,92 +156,101 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in token_descr ])
+        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
+        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
+        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
+        return torch.tensor(id_descr, device = self.device)
+
+    def trim(self, x, token = '<nul>'):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
     def __init__(self, batch_size,
-                 height, width, many_colors = False,
+                 height, width, nb_colors = 5,
                  device = torch.device('cpu')):
 
         def generate_descr(nb):
-            descr = picoclvr.generate(
+            return picoclvr.generate(
                 nb,
                 height = self.height, width = self.width,
-                many_colors = many_colors
+                nb_colors = nb_colors
             )
 
-            descr = [ s.strip().split(' ') for s in descr ]
-            l = max([ len(s) for s in descr ])
-            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
-
-            return descr
-
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string('generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
         # Build the tokenizer
-        tokens = set()
+        tokens = { '<nul>' }
         for d in [ self.train_descr, self.test_descr ]:
             for s in d:
-                for t in s: tokens.add(t)
+                for t in s.strip().split(' '): tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
-        self.train_input = torch.tensor(t, device = self.device)
-        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
-        self.test_input = torch.tensor(t, device = self.device)
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
+        input = self.train_input if split == 'train' else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def generate(self, primer, model, nb_tokens):
-        t_primer = primer.strip().split(' ')
-        t_generated = [ ]
-
-        for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-            input = torch.tensor(t, device = self.device)
-            output = model(input)
-            logits = output[0, -1]
-            if args.synthesis_sampling:
-                dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
-            else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
-
-        return ' '.join(t_primer + t_generated)
-
-    def produce_results(self, n_epoch, model, nb_tokens = None):
-        if nb_tokens is None:
-            nb_tokens = self.height * self.width + 3
-        descr = [ ]
+    def produce_results(self, n_epoch, model):
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = [ ]
         nb_per_primer = 8
 
-        for primer in [
+        for primer_descr in [
                 'red above green <sep> green top <sep> blue right of red <img>',
                 'there is red <sep> there is yellow <sep> there is blue <img>',
                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
         ]:
 
-            for k in range(nb_per_primer):
-                descr.append(self.generate(primer, model, nb_tokens))
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
+        )
+
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+
+        img = [
+            picoclvr.descr2img(d, height = self.height, width = self.width)
+            for d in result_descr
+        ]
 
-        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
         img = torch.cat(img, 0)
         image_name = f'result_picoclvr_{n_epoch:04d}.png'
         torchvision.utils.save_image(
@@ -242,15 +259,6 @@ class TaskPicoCLVR(Task):
         )
         log_string(f'wrote {image_name}')
 
-        nb_missing = sum( [
-            x[2] for x in picoclvr.nb_missing_properties(
-                descr,
-                height = self.height, width = self.width
-            )
-        ] )
-
-        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
-
 ######################################################################
 
 class TaskWiki103(Task):
@@ -277,15 +285,16 @@ class TaskWiki103(Task):
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
             yield_tokens(),
-            specials = [ '<unk>', '<non>' ],
+            specials = [ '<unk>', '<nul>' ],
             min_freq = self.min_freq
         )
 
         self.vocab.set_default_index(self.vocab[ '<unk>' ])
 
+    # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
+        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
 
     def yield_batches(self, ds):
         s = [ ]
@@ -312,7 +321,8 @@ class TaskWiki103(Task):
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
         file_name = f'result_wiki103_{n_epoch:04d}.txt'
 
         with open(file_name, 'w') as outfile:
@@ -333,15 +343,16 @@ class TaskWiki103(Task):
                  for j in range(nb_tokens):
 
                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
                      if args.synthesis_sampling:
                          dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t = dist.sample()
+                         t_next = dist.sample()
                      else:
-                         t = logits.argmax()
-                     t_generated.append(self.vocab.lookup_token(t))
-                     if t_generated[-1] == '<non>': break
+                         t_next = logits.argmax()
+                     t_generated.append(self.vocab.lookup_token(t_next))
+                     if t_generated[-1] == '<nul>': break
 
                  s = ' '.join(t_generated)
 
@@ -372,19 +383,9 @@ class TaskMNIST(Task):
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s] = t
-
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
@@ -405,7 +406,7 @@ elif args.data == 'picoclvr':
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height,
                         width = args.picoclvr_width,
-                        many_colors = args.picoclvr_many_colors,
+                        nb_colors = args.picoclvr_nb_colors,
                         device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
@@ -443,7 +444,7 @@ else:
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -451,13 +452,13 @@ else:
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -468,9 +469,8 @@ token_count = 0
 for input in task.batches(split = 'train'):
     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
 token_probas = token_count / token_count.sum()
-h = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
 
 for k in range(nb_epochs_finished, nb_epochs):
 
@@ -498,14 +498,14 @@ for k in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split = 'test'):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
+        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
         task.produce_results(k, model)