Remove the --download option, replace --picoclvr_many_colors with a --picoclvr_nb_colors count, factor token synthesis into an autoregression() helper, and lowercase the log messages.
diff --git a/main.py b/main.py
index e973291..cd0e1ea 100755
--- a/main.py
+++ b/main.py
@@ -24,9 +24,6 @@ parser = argparse.ArgumentParser(description = 'My own GPT.')
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
@@ -78,8 +75,8 @@ parser.add_argument('--checkpoint_name',
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
 
 parser.add_argument('--picoclvr_height',
                     type = int, default = 12)
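
With the boolean flag gone, the number of colors is chosen directly on the command line; for example (the value 10 is illustrative, and --data is the dataset selector used further down in this file):

    python main.py --data picoclvr --picoclvr_nb_colors 10
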
@@ -113,22 +110,34 @@ for n in vars(args):
 
 ######################################################################
 
-def produce_results(
-        self,
-        model, nb_samples, nb_tokens_to_generate, starting_input = None,
-        device = 'cpu'
+def autoregression(
+        model,
+        nb_samples, nb_tokens_to_generate, starting_input = None,
+        device = torch.device('cpu')
 ):
-    results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
-    for input in results.split(self.batch_size):
-        for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if starting_input is None:
+        first = 0
+    else:
+        first = starting_input.size(1)
+        results = torch.cat((starting_input, results), 1)
+
+    for input in results.split(args.batch_size):
+        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax(1)
-            input[:, s + 1] = t
+                t_next = logits.argmax(1)
+            input[:, s] = t_next
+
+    return results
 
 ######################################################################
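
For reference, a minimal usage sketch of the new autoregression() helper, assuming it runs inside this script (it reads args.batch_size and args.synthesis_sampling from module scope) and that model returns logits of shape (nb_samples, length, vocab_size); the primer values are hypothetical:

    # Prime 16 sequences with two copies of token 3, then let the model
    # fill in 30 more positions. nb_samples must match the primer's
    # first dimension; results then has shape (16, 2 + 30) = (16, 32).
    primer = torch.full((16, 2), 3, dtype = torch.int64)
    results = autoregression(
        model, nb_samples = 16, nb_tokens_to_generate = 30,
        starting_input = primer,
    )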
 
@@ -149,14 +158,14 @@ import picoclvr
 class TaskPicoCLVR(Task):
 
     def __init__(self, batch_size,
-                 height, width, many_colors = False,
+                 height, width, nb_colors = 5,
                  device = torch.device('cpu')):
 
         def generate_descr(nb):
             descr = picoclvr.generate(
                 nb,
                 height = self.height, width = self.width,
-                many_colors = many_colors
+                nb_colors = nb_colors
             )
 
             descr = [ s.strip().split(' ') for s in descr ]
@@ -204,16 +213,17 @@ class TaskPicoCLVR(Task):
         t_generated = [ ]
 
         for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ]
+            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
             input = torch.tensor(t, device = self.device)
+            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
             output = model(input)
             logits = output[0, -1]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
+                t_next = logits.argmax()
+            t_generated.append(self.id2token[t_next.item()])
 
         return ' '.join(t_primer + t_generated)
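
The change above replaces appending a literal [ 0 ] to the token list with an explicit F.pad on the tensor. Both add one slot at the end for the token about to be predicted (the logits at the last position, output[0, -1], are what gets sampled); the pad just states that intent directly. A toy illustration with hypothetical token ids:

    input = torch.tensor([[5, 2, 7]])   # shape (1, 3)
    input = F.pad(input, (0, 1))        # shape (1, 4): [[5, 2, 7, 0]]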
 
@@ -333,14 +343,15 @@ class TaskWiki103(Task):
                  for j in range(nb_tokens):
 
                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
                      if args.synthesis_sampling:
                          dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t = dist.sample()
+                         t_next = dist.sample()
                      else:
-                         t = logits.argmax()
-                     t_generated.append(self.vocab.lookup_token(t))
+                         t_next = logits.argmax()
+                     t_generated.append(self.vocab.lookup_token(t_next))
                      if t_generated[-1] == '<non>': break
 
                  s = ' '.join(t_generated)
@@ -373,18 +384,7 @@ class TaskMNIST(Task):
         return 256
 
     def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s] = t
-
+        results = autoregression(model, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
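
This refactor relies on torch.Tensor.split returning views of the parent tensor: the in-place writes inside autoregression() fill results batch by batch, so returning results collects everything that was generated. A minimal demonstration of that property:

    parent = torch.zeros(4, 3, dtype = torch.int64)
    for chunk in parent.split(2):
        chunk[:, 0] = 9                  # mutates parent in place
    # parent[:, 0] is now tensor([9, 9, 9, 9])
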
@@ -405,7 +405,7 @@ elif args.data == 'picoclvr':
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height,
                         width = args.picoclvr_width,
-                        many_colors = args.picoclvr_many_colors,
+                        nb_colors = args.picoclvr_nb_colors,
                         device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
@@ -443,7 +443,7 @@ else:
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -451,13 +451,13 @@ else:
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -470,7 +470,7 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 h = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+log_string(f'train set perplexity {train_set_perplexity}')
 
 for k in range(nb_epochs_finished, nb_epochs):
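
The quantity logged above is the exponential of the empirical token entropy, perplexity = exp(-sum_t p_t log p_t). A worked toy example with illustrative counts:

    token_count = torch.tensor([8., 1., 1.])
    token_probas = token_count / token_count.sum()      # [0.8, 0.1, 0.1]
    h = -torch.xlogy(token_probas, token_probas).sum()  # ~0.639 nats
    train_set_perplexity = math.exp(h)                  # ~1.89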