From: Francois Fleuret Date: Wed, 27 Jul 2022 09:18:06 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=8ea0e3c5cc303718a8b508b656f7aa9e64ea3070 Update. --- diff --git a/main.py b/main.py index 6592204..b2adf98 100755 --- a/main.py +++ b/main.py @@ -111,7 +111,7 @@ for n in vars(args): ###################################################################### def autoregression( - model, + model, batch_size, nb_samples, nb_tokens_to_generate, starting_input = None, device = torch.device('cpu') ): @@ -126,7 +126,7 @@ def autoregression( first = starting_input.size(1) results = torch.cat((starting_input, results), 1) - for input in results.split(args.batch_size): + for input in results.split(batch_size): for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] @@ -386,7 +386,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = autoregression(model, nb_samples, 28 * 28, device = self.device) + results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) diff --git a/result_picoclvr_0007.png b/result_picoclvr_0007.png deleted file mode 100644 index 7baee57..0000000 Binary files a/result_picoclvr_0007.png and /dev/null differ diff --git a/result_picoclvr_0009.png b/result_picoclvr_0009.png new file mode 100644 index 0000000..18dad27 Binary files /dev/null and b/result_picoclvr_0009.png differ