OCDC
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index a6940f1..83227bb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -18,20 +18,16 @@ import mygpt
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-
 parser = argparse.ArgumentParser(description = 'My own GPT.')
 
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    type = bool, default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
 parser.add_argument('--nb_epochs',
-                    type = int, default = 100)
+                    type = int, default = -1)
 
 parser.add_argument('--batch_size',
                     type = int, default = 25)
@@ -67,7 +63,25 @@ parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
 parser.add_argument('--synthesis_sampling',
-                    type = bool, default = True)
+                    action='store_true', default = True)
+
+parser.add_argument('--no_checkpoint',
+                    action='store_true', default = False)
+
+parser.add_argument('--checkpoint_name',
+                    type = str, default = 'checkpoint.pth')
+
+##############################
+# picoclvr options
+
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
+
+parser.add_argument('--picoclvr_height',
+                    type = int, default = 12)
+
+parser.add_argument('--picoclvr_width',
+                    type = int, default = 16)
 
 ######################################################################
 
@@ -95,6 +109,37 @@ for n in vars(args):
 
 ######################################################################
 
+def autoregression(
+        model, batch_size,
+        nb_samples, nb_tokens_to_generate, primer = None,
+        device = torch.device('cpu')
+):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if primer is None:
+        first = 0
+    else:
+        first = primer.size(1)
+        results = torch.cat((primer, results), 1)
+
+    for input in results.split(batch_size):
+        for s in range(first, input.size(1)):
+            output = model(input)
+            logits = output[:, s]
+            if args.synthesis_sampling:
+                dist = torch.distributions.categorical.Categorical(logits = logits)
+                t_next = dist.sample()
+            else:
+                t_next = logits.argmax(1)
+            input[:, s] = t_next
+
+    return results
+
+######################################################################
+
 class Task:
     def batches(self, split = 'train'):
         pass
@@ -102,7 +147,7 @@ class Task:
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
 ######################################################################
@@ -111,75 +156,107 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
-    def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in token_descr ])
+        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
+        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
+        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
+        return torch.tensor(id_descr, device = self.device)
+
+    def trim(self, x, token = '<nul>'):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
+    def __init__(self, batch_size,
+                 height, width, nb_colors = 5,
+                 device = torch.device('cpu')):
+
+        def generate_descr(nb):
+            return picoclvr.generate(
+                nb,
+                height = self.height, width = self.width,
+                nb_colors = nb_colors
+            )
+
+        self.height = height
+        self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
-        descr = picoclvr.generate(nb, height = height, width = width)
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+        self.train_descr = generate_descr((nb * 4) // 5)
+        self.test_descr = generate_descr((nb * 1) // 5)
 
-        tokens = set()
-        for s in descr:
-            for t in s: tokens.add(t)
+        # Build the tokenizer
+        tokens = { '<nul>' }
+        for d in [ self.train_descr, self.test_descr ]:
+            for s in d:
+                for t in s.strip().split(' '): tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        data_input = torch.tensor(t, device = self.device)
-
-        self.test_input = data_input[:nb // 5]
-        self.train_input = data_input[nb // 5:]
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
-                yield batch
+        input = self.train_input if split == 'train' else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        img = [ ]
+    def produce_results(self, n_epoch, model):
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = [ ]
         nb_per_primer = 8
 
-        for primer in [
+        for primer_descr in [
                 'red above green <sep> green top <sep> blue right of red <img>',
                 'there is red <sep> there is yellow <sep> there is blue <img>',
                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
         ]:
 
-            for k in range(nb_per_primer):
-                t_primer = primer.strip().split(' ')
-                t_generated = [ ]
-
-                for j in range(nb_tokens):
-                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-                    input = torch.tensor(t, device = self.device)
-                    output = model(input)
-                    logits = output[0, -1]
-                    if args.synthesis_sampling:
-                        dist = torch.distributions.categorical.Categorical(logits = logits)
-                        t = dist.sample()
-                    else:
-                        t = logits.argmax()
-                    t_generated.append(self.id2token[t.item()])
-
-                descr = [ ' '.join(t_primer + t_generated) ]
-                img += [ picoclvr.descr2img(descr) ]
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
+        )
+
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+
+        img = [
+            picoclvr.descr2img(d, height = self.height, width = self.width)
+            for d in result_descr
+        ]
 
         img = torch.cat(img, 0)
-        file_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(img / 255.,
-                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
-        log_string(f'wrote {file_name}')
+        image_name = f'result_picoclvr_{n_epoch:04d}.png'
+        torchvision.utils.save_image(
+            img / 255.,
+            image_name, nrow = nb_per_primer, pad_value = 0.8
+        )
+        log_string(f'wrote {image_name}')
 
 ######################################################################
 
@@ -207,15 +284,16 @@ class TaskWiki103(Task):
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
             yield_tokens(),
-            specials = [ '<unk>', '<non>' ],
+            specials = [ '<unk>', '<nul>' ],
             min_freq = self.min_freq
         )
 
         self.vocab.set_default_index(self.vocab[ '<unk>' ])
 
+    # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
+        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
 
     def yield_batches(self, ds):
         s = [ ]
@@ -237,12 +315,13 @@ class TaskWiki103(Task):
         if args.data_size > 0:
             data_iter = itertools.islice(data_iter, args.data_size)
 
-        return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
+        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
 
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
         file_name = f'result_wiki103_{n_epoch:04d}.txt'
 
         with open(file_name, 'w') as outfile:
@@ -263,15 +342,16 @@ class TaskWiki103(Task):
                  for j in range(nb_tokens):
 
                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
                      if args.synthesis_sampling:
                          dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t = dist.sample()
+                         t_next = dist.sample()
                      else:
-                         t = logits.argmax()
-                     t_generated.append(self.vocab.lookup_token(t))
-                     if t_generated[-1] == '<non>': break
+                         t_next = logits.argmax()
+                     t_generated.append(self.vocab.lookup_token(t_next))
+                     if t_generated[-1] == '<nul>': break
 
                  s = ' '.join(t_generated)
 
@@ -296,25 +376,15 @@ class TaskMNIST(Task):
         data_input = data_set.data.view(-1, 28 * 28).long()
         if args.data_size >= 0:
             data_input = data_input[:args.data_size]
-        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
+        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
             yield batch
 
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s + 1] = t
-
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
@@ -322,27 +392,21 @@ class TaskMNIST(Task):
 
 ######################################################################
 
-def check_causality(model):
-    #m = model[1:]
-    input = torch.rand(1, 5, dim_model).requires_grad_()
-    output = m(input)
-    a = torch.zeros(output.size(1), input.size(1))
-    for k in range(output.size(1)):
-        for d in range(output.size(2)):
-            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
-            a[k] += g.squeeze(0).pow(2).sum(1)
-    print(a)
-
-######################################################################
-
 log_string(f'device {device}')
 
 if args.data == 'wiki103':
+    nb_epochs_default = 10
     task = TaskWiki103(batch_size = args.batch_size, device = device)
 elif args.data == 'mnist':
+    nb_epochs_default = 25
     task = TaskMNIST(batch_size = args.batch_size, device = device)
 elif args.data == 'picoclvr':
-    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+    nb_epochs_default = 10
+    task = TaskPicoCLVR(batch_size = args.batch_size,
+                        height = args.picoclvr_height,
+                        width = args.picoclvr_width,
+                        nb_colors = args.picoclvr_nb_colors,
+                        device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
 
@@ -358,11 +422,11 @@ model = mygpt.MyGPT(
     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
 )
 
+model.to(device)
+
 nb_parameters = sum(p.numel() for p in model.parameters())
 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
-model.to(device)
-
 ######################################################################
 
 if args.optim == 'sgd':
@@ -374,7 +438,40 @@ elif args.optim == 'adamw':
 else:
     raise ValueError(f'Unknown optimizer {args.optim}.')
 
-for k in range(args.nb_epochs):
+######################################################################
+
+nb_epochs_finished = 0
+
+if args.no_checkpoint:
+    log_string(f'not trying to load checkpoint.')
+
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        nb_epochs_finished = checkpoint['nb_epochs_finished']
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+    except FileNotFoundError:
+        log_string('starting from scratch.')
+
+    except:
+        log_string('error when loading the checkpoint.')
+        exit(1)
+
+######################################################################
+
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
+token_count = 0
+for input in task.batches(split = 'train'):
+    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+for k in range(nb_epochs_finished, nb_epochs):
 
     model.train()
 
@@ -383,7 +480,7 @@ for k in range(args.nb_epochs):
     for input in task.batches(split = 'train'):
         input = input.to(device)
         output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
 
@@ -400,15 +497,23 @@ for k in range(args.nb_epochs):
         for input in task.batches(split = 'test'):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
         task.produce_results(k, model)
 
+    checkpoint = {
+        'nb_epochs_finished': k + 1,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict()
+    }
+
+    torch.save(checkpoint, args.checkpoint_name)
+
 ######################################################################