Moved the input/output shift in the forward of the model.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:35:19 +0000 (12:35 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:35:19 +0000 (12:35 +0200)
main.py
mygpt.py
readme.txt

diff --git a/main.py b/main.py
index 6c1def7..c810eef 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -31,7 +31,7 @@ parser.add_argument('--seed',
                     type = int, default = 0)
 
 parser.add_argument('--nb_epochs',
-                    type = int, default = 100)
+                    type = int, default = -1)
 
 parser.add_argument('--batch_size',
                     type = int, default = 25)
@@ -113,6 +113,25 @@ for n in vars(args):
 
 ######################################################################
 
+def produce_results(
+        self,
+        model, nb_samples, nb_tokens_to_generate, starting_input = None,
+        device = 'cpu'
+):
+    results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
+    for input in results.split(self.batch_size):
+        for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+            output = model(input)
+            logits = output[:, s]
+            if args.synthesis_sampling:
+                dist = torch.distributions.categorical.Categorical(logits = logits)
+                t = dist.sample()
+            else:
+                t = logits.argmax(1)
+            input[:, s + 1] = t
+
+######################################################################
+
 class Task:
     def batches(self, split = 'train'):
         pass
@@ -356,7 +375,7 @@ class TaskMNIST(Task):
     def produce_results(self, n_epoch, model, nb_samples = 64):
         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
         for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
                 output = model(input)
                 logits = output[:, s]
                 if args.synthesis_sampling:
@@ -364,7 +383,7 @@ class TaskMNIST(Task):
                     t = dist.sample()
                 else:
                     t = logits.argmax(1)
-                input[:, s + 1] = t
+                input[:, s] = t
 
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
@@ -373,26 +392,16 @@ class TaskMNIST(Task):
 
 ######################################################################
 
-def check_causality(model):
-    #m = model[1:]
-    input = torch.rand(1, 5, dim_model).requires_grad_()
-    output = m(input)
-    a = torch.zeros(output.size(1), input.size(1))
-    for k in range(output.size(1)):
-        for d in range(output.size(2)):
-            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
-            a[k] += g.squeeze(0).pow(2).sum(1)
-    print(a)
-
-######################################################################
-
 log_string(f'device {device}')
 
 if args.data == 'wiki103':
+    nb_epochs_default = 10
     task = TaskWiki103(batch_size = args.batch_size, device = device)
 elif args.data == 'mnist':
+    nb_epochs_default = 25
     task = TaskMNIST(batch_size = args.batch_size, device = device)
 elif args.data == 'picoclvr':
+    nb_epochs_default = 10
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height,
                         width = args.picoclvr_width,
@@ -453,7 +462,17 @@ else:
 
 ######################################################################
 
-for k in range(nb_epochs_finished, args.nb_epochs):
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
+token_count = 0
+for input in task.batches(split = 'train'):
+    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+h = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(h)
+log_string(f'Train set perplexity {train_set_perplexity}')
+
+for k in range(nb_epochs_finished, nb_epochs):
 
     model.train()
 
@@ -462,7 +481,7 @@ for k in range(nb_epochs_finished, args.nb_epochs):
     for input in task.batches(split = 'train'):
         input = input.to(device)
         output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
 
@@ -486,7 +505,7 @@ for k in range(nb_epochs_finished, args.nb_epochs):
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
 
         task.produce_results(k, model)
 
index 37fe6af..7f0c9e6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -127,10 +127,11 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
+        x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x
+        return x[:, :-1]
 
 ######################################################################
 
index 40a442f..74e7e9e 100644 (file)
@@ -1,4 +1,8 @@
 
+To run the MNIST experiment:
+
+  ./main.py --data=mnist
+
 To run the picoclvr experiment:
 
-  ./main.py --data=picoclvr --nb_epochs=8
+  ./main.py --data=picoclvr