From fc570d4ccd5d5dee36271d34ff5c672a50a82101 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 12:35:19 +0200 Subject: [PATCH] Moved the input/output shift in the forward of the model. --- main.py | 57 ++++++++++++++++++++++++++++++++++++------------------ mygpt.py | 3 ++- readme.txt | 6 +++++- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 6c1def7..c810eef 100755 --- a/main.py +++ b/main.py @@ -31,7 +31,7 @@ parser.add_argument('--seed', type = int, default = 0) parser.add_argument('--nb_epochs', - type = int, default = 100) + type = int, default = -1) parser.add_argument('--batch_size', type = int, default = 25) @@ -113,6 +113,25 @@ for n in vars(args): ###################################################################### +def produce_results( + self, + model, nb_samples, nb_tokens_to_generate, starting_input = None, + device = 'cpu' +): + results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device) + for input in results.split(self.batch_size): + for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax(1) + input[:, s + 1] = t + +###################################################################### + class Task: def batches(self, split = 'train'): pass @@ -356,7 +375,7 @@ class TaskMNIST(Task): def produce_results(self, n_epoch, model, nb_samples = 64): results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] if args.synthesis_sampling: @@ -364,7 +383,7 @@ class TaskMNIST(Task): t = dist.sample() else: t = logits.argmax(1) - input[:, s + 1] = t + input[:, s] = t image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., @@ -373,26 +392,16 @@ class TaskMNIST(Task): ###################################################################### -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - log_string(f'device {device}') if args.data == 'wiki103': + nb_epochs_default = 10 task = TaskWiki103(batch_size = args.batch_size, device = device) elif args.data == 'mnist': + nb_epochs_default = 25 task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': + nb_epochs_default = 10 task = TaskPicoCLVR(batch_size = args.batch_size, height = args.picoclvr_height, width = args.picoclvr_width, @@ -453,7 +462,17 @@ else: ###################################################################### -for k in range(nb_epochs_finished, args.nb_epochs): +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +token_count = 0 +for input in task.batches(split = 'train'): + token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +h = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(h) +log_string(f'Train set perplexity {train_set_perplexity}') + +for k in range(nb_epochs_finished, nb_epochs): model.train() @@ -462,7 +481,7 @@ for k in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split = 'train'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -486,7 +505,7 @@ for k in range(nb_epochs_finished, args.nb_epochs): train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') + log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}') task.produce_results(k, model) diff --git a/mygpt.py b/mygpt.py index 37fe6af..7f0c9e6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -127,10 +127,11 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): + x = torch.cat((x.new_zeros(x.size(0), 1), x), 1) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - return x + return x[:, :-1] ###################################################################### diff --git a/readme.txt b/readme.txt index 40a442f..74e7e9e 100644 --- a/readme.txt +++ b/readme.txt @@ -1,4 +1,8 @@ +To run the MNIST experiment: + + ./main.py --data=mnist + To run the picoclvr experiment: - ./main.py --data=picoclvr --nb_epochs=8 + ./main.py --data=picoclvr -- 2.20.1