X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=f65bb8eafde0e667d23f0d12746a480383bf85ff;hb=f082ce9255f87b6c1bfebfac00f94820a16e04f1;hp=c810eef06593c7939e0ad15fe690225509cd4150;hpb=fc570d4ccd5d5dee36271d34ff5c672a50a82101;p=mygpt.git

diff --git a/main.py b/main.py
index c810eef..f65bb8e 100755
--- a/main.py
+++ b/main.py
@@ -18,15 +18,11 @@ import mygpt
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-
 parser = argparse.ArgumentParser(description = 'My own GPT.')
 
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
@@ -78,8 +74,8 @@ parser.add_argument('--checkpoint_name',
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
 
 parser.add_argument('--picoclvr_height',
                     type = int, default = 12)
@@ -113,22 +109,34 @@ for n in vars(args):
 
 ######################################################################
 
-def produce_results(
-        self,
-        model, nb_samples, nb_tokens_to_generate, starting_input = None,
-        device = 'cpu'
+def autoregression(
+        model, batch_size,
+        nb_samples, nb_tokens_to_generate, primer = None,
+        device = torch.device('cpu')
 ):
-    results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
-    for input in results.split(self.batch_size):
-        for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if primer is None:
+        first = 0
+    else:
+        first = primer.size(1)
+        results = torch.cat((primer, results), 1)
+
+    for input in results.split(batch_size):
+        for s in range(first, input.size(1)):
             output = model(input)
             logits = output[:, s]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax(1)
-            input[:, s + 1] = t
+                t_next = logits.argmax(1)
+            input[:, s] = t_next
+
+    return results
 
 ######################################################################
 
@@ -139,7 +147,7 @@ class Task:
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
 ######################################################################
@@ -148,92 +156,101 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in token_descr ])
+        #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ]
+        token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ]
+        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
+        return torch.tensor(id_descr, device = self.device)
+
+    def trim(self, x, token = ''):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
     def __init__(self, batch_size,
-                 height, width, many_colors = False,
+                 height, width, nb_colors = 5,
                  device = torch.device('cpu')):
 
         def generate_descr(nb):
-            descr = picoclvr.generate(
+            return picoclvr.generate(
                 nb,
                 height = self.height, width = self.width,
-                many_colors = many_colors
+                nb_colors = nb_colors
            )
 
-            descr = [ s.strip().split(' ') for s in descr ]
-            l = max([ len(s) for s in descr ])
-            descr = [ s + [ '' ] * (l - len(s)) for s in descr ]
-
-            return descr
-
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string(f'generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
         # Build the tokenizer
-        tokens = set()
+        tokens = { '' }
         for d in [ self.train_descr, self.test_descr ]:
             for s in d:
-                for t in s: tokens.add(t)
+                for t in s.strip().split(' '): tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
-        self.train_input = torch.tensor(t, device = self.device)
-        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
-        self.test_input = torch.tensor(t, device = self.device)
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
+        input = self.train_input if split == 'train' else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def generate(self, primer, model, nb_tokens):
-        t_primer = primer.strip().split(' ')
-        t_generated = [ ]
-
-        for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-            input = torch.tensor(t, device = self.device)
-            output = model(input)
-            logits = output[0, -1]
-            if args.synthesis_sampling:
-                dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
-            else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
-
-        return ' '.join(t_primer + t_generated)
-
-    def produce_results(self, n_epoch, model, nb_tokens = None):
-        if nb_tokens is None:
-            nb_tokens = self.height * self.width + 3
-        descr = [ ]
+    def produce_results(self, n_epoch, model):
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = [ ]
         nb_per_primer = 8
 
-        for primer in [
+        for primer_descr in [
             'red above green green top blue right of red ',
             'there is red there is yellow there is blue ',
             'red below yellow yellow below green green below blue red right yellow left green right blue left ',
             'green bottom yellow bottom green left of blue yellow right of blue blue top ',
         ]:
-            for k in range(nb_per_primer):
-                descr.append(self.generate(primer, model, nb_tokens))
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
+        )
+
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+
+        img = [
+            picoclvr.descr2img(d, height = self.height, width = self.width)
+            for d in result_descr
+        ]
 
-        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
         img = torch.cat(img, 0)
         image_name = f'result_picoclvr_{n_epoch:04d}.png'
         torchvision.utils.save_image(
@@ -242,15 +259,6 @@ class TaskPicoCLVR(Task):
         )
         log_string(f'wrote {image_name}')
 
-        nb_missing = sum( [
-            x[2] for x in picoclvr.nb_missing_properties(
-                descr,
-                height = self.height, width = self.width
-            )
-        ] )
-
-        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
-
 ######################################################################
 
 class TaskWiki103(Task):
@@ -277,15 +285,16 @@
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
             yield_tokens(),
-            specials = [ '', '' ],
+            specials = [ '', '' ],
             min_freq = self.min_freq
         )
 
         self.vocab.set_default_index(self.vocab[ '' ])
 
+    # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])
+        return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])
 
     def yield_batches(self, ds):
         s = [ ]
@@ -312,7 +321,8 @@
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
         file_name = f'result_wiki103_{n_epoch:04d}.txt'
 
         with open(file_name, 'w') as outfile:
@@ -333,15 +343,16 @@
 
                 for j in range(nb_tokens):
                     input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                     output = model(input)
                     logits = output[0, -1]
                     if args.synthesis_sampling:
                         dist = torch.distributions.categorical.Categorical(logits = logits)
-                        t = dist.sample()
+                        t_next = dist.sample()
                     else:
-                        t = logits.argmax()
-                    t_generated.append(self.vocab.lookup_token(t))
-                    if t_generated[-1] == '': break
+                        t_next = logits.argmax()
+                    t_generated.append(self.vocab.lookup_token(t_next))
+                    if t_generated[-1] == '': break
 
                 s = ' '.join(t_generated)
@@ -372,19 +383,9 @@ class TaskMNIST(Task):
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s] = t
-
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
@@ -405,7 +406,7 @@ elif args.data == 'picoclvr':
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height, width = args.picoclvr_width,
-                        many_colors = args.picoclvr_many_colors,
+                        nb_colors = args.picoclvr_nb_colors,
                         device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
 
@@ -443,7 +444,7 @@ else:
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -451,13 +452,13 @@ else:
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -468,9 +469,8 @@ token_count = 0
 for input in task.batches(split = 'train'):
     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
 token_probas = token_count / token_count.sum()
-h = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
 
 for k in range(nb_epochs_finished, nb_epochs):
 
@@ -498,14 +498,14 @@ for k in range(nb_epochs_finished, nb_epochs):
 
         for input in task.batches(split = 'test'):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
+        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
         task.produce_results(k, model)
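
A minimal usage sketch of the autoregression() helper introduced by this patch (the sketch itself is not part of the commit). It mirrors the call made in TaskMNIST.produce_results() and assumes a trained mygpt.MyGPT instance bound to `model` and the `device` object defined near the top of main.py; the batch size and sample count are illustrative values.

# Illustrative only: generate 64 MNIST-style sequences of 28 * 28 pixel tokens,
# processed 25 sequences per forward pass, using the helper defined above.
results = autoregression(
    model, batch_size = 25,
    nb_samples = 64, nb_tokens_to_generate = 28 * 28,
    device = device
)
# `results` is a (64, 784) int64 tensor of sampled token indices; passing a
# `primer` tensor instead prepends it and generates nb_tokens_to_generate
# additional tokens after it.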