Added args.learning_rate_end for an exponential decay.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index 1b011a2..f6934b7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -156,13 +159,20 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
+    # Make a tensor from a list of strings
     def tensorize(self, descr):
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
-        descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        return torch.tensor(t, device = self.device)
+        token_descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in token_descr ])
+        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
+        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
+        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
+        return torch.tensor(id_descr, device = self.device)
+
+    def trim(self, x, token = '<nul>'):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
 
     def __init__(self, batch_size,
                  height, width, nb_colors = 5,
@@ -181,6 +191,7 @@ class TaskPicoCLVR(Task):
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string(f'generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
@@ -200,13 +211,13 @@ class TaskPicoCLVR(Task):
         assert split in { 'train', 'test' }
         input = self.train_input if split == 'train' else self.test_input
         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
-            yield batch
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
     def produce_results(self, n_epoch, model):
-        nb_tokens = self.height * self.width + 3
+        nb_tokens_to_generate = self.height * self.width + 3
         result_descr = [ ]
         nb_per_primer = 8
 
@@ -217,15 +228,26 @@ class TaskPicoCLVR(Task):
                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
         ]:
 
-            for k in range(nb_per_primer):
-                results = autoregression(
-                    model, self.batch_size,
-                    nb_samples = 1, nb_tokens_to_generate = nb_tokens,
-                    primer = self.tensorize([ primer_descr ]),
-                    device = self.device
-                )
-                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
-                result_descr.append(r)
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
+        )
+
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
 
         img = [
             picoclvr.descr2img(d, height = self.height, width = self.width)
@@ -240,15 +262,6 @@ class TaskPicoCLVR(Task):
         )
         log_string(f'wrote {image_name}')
 
-        np = picoclvr.nb_properties(
-            result_descr,
-            height = self.height, width = self.width
-        )
-
-        nb_requested_properties, _, nb_missing_properties = zip(*np)
-
-        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
-
 ######################################################################
 
 class TaskWiki103(Task):
@@ -281,6 +294,7 @@ class TaskWiki103(Task):
 
         self.vocab.set_default_index(self.vocab[ '<unk>' ])
 
+    # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
@@ -419,17 +433,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
@@ -437,10 +440,12 @@ if args.no_checkpoint:
 
 else:
     try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
@@ -460,9 +465,25 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
-#log_string(f'train set perplexity {train_set_perplexity}')
 
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+        log_string(f'learning_rate {lr}')
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -495,16 +516,19 @@ for k in range(nb_epochs_finished, nb_epochs):
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+        log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-        task.produce_results(k, model)
+        task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################