X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=77b1b226f95c4d9a11dd6fe5990d575ae0795399;hb=b4c255babaae72d6a03b4c8e8e7e25f6ab0a19a0;hp=a18beb1ecc97ea456f3ea24ea1b518f05379db65;hpb=73507ec9ff7677eefdafabe4123c9b05cab28f8f;p=mygpt.git diff --git a/main.py b/main.py index a18beb1..77b1b22 100755 --- a/main.py +++ b/main.py @@ -24,14 +24,11 @@ parser = argparse.ArgumentParser(description = 'My own GPT.') parser.add_argument('--log_filename', type = str, default = 'train.log') -parser.add_argument('--download', - action='store_true', default = False) - parser.add_argument('--seed', type = int, default = 0) parser.add_argument('--nb_epochs', - type = int, default = 100) + type = int, default = -1) parser.add_argument('--batch_size', type = int, default = 25) @@ -69,11 +66,23 @@ parser.add_argument('--dropout', parser.add_argument('--synthesis_sampling', action='store_true', default = True) +parser.add_argument('--no_checkpoint', + action='store_true', default = False) + parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') -parser.add_argument('--picoclvr_many_colors', - action='store_true', default = False) +############################## +# picoclvr options + +parser.add_argument('--picoclvr_nb_colors', + type = int, default = 5) + +parser.add_argument('--picoclvr_height', + type = int, default = 12) + +parser.add_argument('--picoclvr_width', + type = int, default = 16) ###################################################################### @@ -101,6 +110,37 @@ for n in vars(args): ###################################################################### +def autoregression( + model, batch_size, + nb_samples, nb_tokens_to_generate, primer = None, + device = torch.device('cpu') +): + results = torch.zeros( + nb_samples, nb_tokens_to_generate, + dtype = torch.int64, device = device + ) + + if primer is None: + first = 0 + else: + first = primer.size(1) + results = torch.cat((primer, results), 1) + + for input in results.split(batch_size): + for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t_next = dist.sample() + else: + t_next = logits.argmax(1) + input[:, s] = t_next + + return results + +###################################################################### + class Task: def batches(self, split = 'train'): pass @@ -117,38 +157,48 @@ import picoclvr class TaskPicoCLVR(Task): + def descr2tensor(self, descr): + t = [ [ self.token2id[u] for u in s ] for s in descr ] + return torch.tensor(t, device = self.device) + def __init__(self, batch_size, - height = 6, width = 8, many_colors = False, + height, width, nb_colors = 5, device = torch.device('cpu')): + def generate_descr(nb): + descr = picoclvr.generate( + nb, + height = self.height, width = self.width, + nb_colors = nb_colors + ) + + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + + return descr + + self.height = height + self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 - descr = picoclvr.generate( - nb, - height = height, width = width, - many_colors = many_colors - ) - - self.test_descr = descr[:nb // 5] - self.train_descr = descr[nb // 5:] - - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + self.train_descr = generate_descr((nb * 4) // 5) + self.test_descr = generate_descr((nb * 1) // 5) + # Build the tokenizer tokens = set() - for s in descr: - for t in s: tokens.add(t) + for d in [ self.train_descr, self.test_descr ]: + for s in d: + for t in s: tokens.add(t) self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] + # Tokenize the train and test sets + self.train_input = descr2tensor(self.train_descr) + self.test_input = descr2tensor(self.test_descr) def batches(self, split = 'train'): assert split in { 'train', 'test' } @@ -162,29 +212,21 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) - def generate(self, primer, model, nb_tokens): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - return ' '.join(t_primer + t_generated) + def generate(self, primer_descr, model, nb_tokens): + results = autoregression( + model, self.batch_size, + 1, nb_tokens, primer = descr2tensor(primer_descr), + device = self.device + ) + return ' '.join([ self.id2token[t.item()] for t in results.flatten() ]) - def produce_results(self, n_epoch, model, nb_tokens = 50): - descr = [ ] + def produce_results(self, n_epoch, model, nb_tokens = None): + if nb_tokens is None: + nb_tokens = self.height * self.width + 3 + result_descr = [ ] nb_per_primer = 8 - for primer in [ + for primer_descr in [ 'red above green green top blue right of red ', 'there is red there is yellow there is blue ', 'red below yellow yellow below green green below blue red right yellow left green right blue left ', @@ -192,16 +234,26 @@ class TaskPicoCLVR(Task): ]: for k in range(nb_per_primer): - descr.append(self.generate(primer, model, nb_tokens)) + result_descr.append(self.generate(primer_descr, model, nb_tokens)) - img = [ picoclvr.descr2img(d) for d in descr ] + img = [ picoclvr.descr2img(d, height = self.height, width = self.width) + for d in result_descr ] img = torch.cat(img, 0) - file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) - log_string(f'wrote {file_name}') + image_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image( + img / 255., + image_name, nrow = nb_per_primer, pad_value = 0.8 + ) + log_string(f'wrote {image_name}') + + np = picoclvr.nb_properties( + result_descr, + height = self.height, width = self.width + ) - log_string(f'nb_misssing {picoclvr.nb_missing_properties(descr)}') + nb_requested_properties, _, nb_missing_properties = zip(*np) + + log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}') ###################################################################### @@ -285,14 +337,15 @@ class TaskWiki103(Task): for j in range(nb_tokens): input = self.tensorize([ t_primer + t_generated ]).to(self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) + t_next = logits.argmax() + t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break s = ' '.join(t_generated) @@ -325,18 +378,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s + 1] = t - + results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) @@ -344,27 +386,21 @@ class TaskMNIST(Task): ###################################################################### -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - log_string(f'device {device}') if args.data == 'wiki103': + nb_epochs_default = 10 task = TaskWiki103(batch_size = args.batch_size, device = device) elif args.data == 'mnist': + nb_epochs_default = 25 task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device) + nb_epochs_default = 10 + task = TaskPicoCLVR(batch_size = args.batch_size, + height = args.picoclvr_height, + width = args.picoclvr_width, + nb_colors = args.picoclvr_nb_colors, + device = device) else: raise ValueError(f'Unknown dataset {args.data}.') @@ -400,23 +436,37 @@ else: nb_epochs_finished = 0 -try: - checkpoint = torch.load(args.checkpoint_name, map_location = device) - nb_epochs_finished = checkpoint['nb_epochs_finished'] - model.load_state_dict(checkpoint['model_state']) - optimizer.load_state_dict(checkpoint['optimizer_state']) - print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.') +if args.no_checkpoint: + log_string(f'not trying to load checkpoint.') + +else: + try: + checkpoint = torch.load(args.checkpoint_name, map_location = device) + nb_epochs_finished = checkpoint['nb_epochs_finished'] + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.') -except FileNotFoundError: - print('Starting from scratch.') + except FileNotFoundError: + log_string('starting from scratch.') -except: - print('Error when loading the checkpoint.') - exit(1) + except: + log_string('error when loading the checkpoint.') + exit(1) ###################################################################### -for k in range(nb_epochs_finished, args.nb_epochs): +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +token_count = 0 +for input in task.batches(split = 'train'): + token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) +#log_string(f'train set perplexity {train_set_perplexity}') + +for k in range(nb_epochs_finished, nb_epochs): model.train() @@ -425,7 +475,7 @@ for k in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split = 'train'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -442,14 +492,14 @@ for k in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') + log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}') task.produce_results(k, model)