X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=83227bb43a24897b506f0b82035a69dc33f7acfe;hb=b6a9cc237cdadac2351814f92c20607d46b0f583;hp=a6940f1dde279cb2800c606e56bba63311b1fe1c;hpb=95d8b6bc41a753f7a12b2a4cd047ea11cdc2054f;p=mygpt.git diff --git a/main.py b/main.py index a6940f1..83227bb 100755 --- a/main.py +++ b/main.py @@ -18,20 +18,16 @@ import mygpt device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') ###################################################################### - parser = argparse.ArgumentParser(description = 'My own GPT.') parser.add_argument('--log_filename', type = str, default = 'train.log') -parser.add_argument('--download', - type = bool, default = False) - parser.add_argument('--seed', type = int, default = 0) parser.add_argument('--nb_epochs', - type = int, default = 100) + type = int, default = -1) parser.add_argument('--batch_size', type = int, default = 25) @@ -67,7 +63,25 @@ parser.add_argument('--dropout', type = float, default = 0.1) parser.add_argument('--synthesis_sampling', - type = bool, default = True) + action='store_true', default = True) + +parser.add_argument('--no_checkpoint', + action='store_true', default = False) + +parser.add_argument('--checkpoint_name', + type = str, default = 'checkpoint.pth') + +############################## +# picoclvr options + +parser.add_argument('--picoclvr_nb_colors', + type = int, default = 5) + +parser.add_argument('--picoclvr_height', + type = int, default = 12) + +parser.add_argument('--picoclvr_width', + type = int, default = 16) ###################################################################### @@ -95,6 +109,37 @@ for n in vars(args): ###################################################################### +def autoregression( + model, batch_size, + nb_samples, nb_tokens_to_generate, primer = None, + device = torch.device('cpu') +): + results = torch.zeros( + nb_samples, nb_tokens_to_generate, + dtype = torch.int64, device = device + ) + + if primer is None: + first = 0 + else: + first = primer.size(1) + results = torch.cat((primer, results), 1) + + for input in results.split(batch_size): + for s in range(first, input.size(1)): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t_next = dist.sample() + else: + t_next = logits.argmax(1) + input[:, s] = t_next + + return results + +###################################################################### + class Task: def batches(self, split = 'train'): pass @@ -102,7 +147,7 @@ class Task: def vocabulary_size(self): pass - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model): pass ###################################################################### @@ -111,75 +156,107 @@ import picoclvr class TaskPicoCLVR(Task): - def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): + # Make a tensor from a list of strings + def tensorize(self, descr): + token_descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in token_descr ]) + #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] + token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] + id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ] + return torch.tensor(id_descr, device = self.device) + + def trim(self, x, token = ''): + n = self.token2id[token] + i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0) + a, b 
= (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return x[:, a:b] + + def __init__(self, batch_size, + height, width, nb_colors = 5, + device = torch.device('cpu')): + + def generate_descr(nb): + return picoclvr.generate( + nb, + height = self.height, width = self.width, + nb_colors = nb_colors + ) + + self.height = height + self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 - descr = picoclvr.generate(nb, height = height, width = width) - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + self.train_descr = generate_descr((nb * 4) // 5) + self.test_descr = generate_descr((nb * 1) // 5) - tokens = set() - for s in descr: - for t in s: tokens.add(t) + # Build the tokenizer + tokens = { '' } + for d in [ self.train_descr, self.test_descr ]: + for s in d: + for t in s.strip().split(' '): tokens.add(t) self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] + # Tokenize the train and test sets + self.train_input = self.tensorize(self.train_descr) + self.test_input = self.tensorize(self.test_descr) def batches(self, split = 'train'): assert split in { 'train', 'test' } - if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): - yield batch - else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): - yield batch + input = self.train_input if split == 'train' else self.test_input + for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'): + yield self.trim(batch) def vocabulary_size(self): return len(self.token2id) - def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] + def produce_results(self, n_epoch, model): + nb_tokens_to_generate = self.height * self.width + 3 + result_descr = [ ] nb_per_primer = 8 - for primer in [ + for primer_descr in [ 'red above green green top blue right of red ', 'there is red there is yellow there is blue ', 'red below yellow yellow below green green below blue red right yellow left green right blue left ', 'green bottom yellow bottom green left of blue yellow right of blue blue top ', ]: - for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] + results = autoregression( + model, + self.batch_size, + nb_samples = nb_per_primer, + nb_tokens_to_generate = nb_tokens_to_generate, + primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1), + device = self.device + ) + + l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ] + result_descr += l + + np = picoclvr.nb_properties( + result_descr, + height = self.height, width = self.width + ) + + nb_requested_properties, _, nb_missing_properties = 
zip(*np) + + log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}') + + img = [ + picoclvr.descr2img(d, height = self.height, width = self.width) + for d in result_descr + ] img = torch.cat(img, 0) - file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) - log_string(f'wrote {file_name}') + image_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image( + img / 255., + image_name, nrow = nb_per_primer, pad_value = 0.8 + ) + log_string(f'wrote {image_name}') ###################################################################### @@ -207,15 +284,16 @@ class TaskWiki103(Task): self.vocab = torchtext.vocab.build_vocab_from_iterator( yield_tokens(), - specials = [ '', '' ], + specials = [ '', '' ], min_freq = self.min_freq ) self.vocab.set_default_index(self.vocab[ '' ]) + # makes a tensor from a list of list of tokens def tensorize(self, s): a = max(len(x) for x in s) - return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) + return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) def yield_batches(self, ds): s = [ ] @@ -237,12 +315,13 @@ class TaskWiki103(Task): if args.data_size > 0: data_iter = itertools.islice(data_iter, args.data_size) - return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) + return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}')) def vocabulary_size(self): return len(self.vocab) - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model): + nb_tokens = 50 file_name = f'result_wiki103_{n_epoch:04d}.txt' with open(file_name, 'w') as outfile: @@ -263,15 +342,16 @@ class TaskWiki103(Task): for j in range(nb_tokens): input = self.tensorize([ t_primer + t_generated ]).to(self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) - if t_generated[-1] == '': break + t_next = logits.argmax() + t_generated.append(self.vocab.lookup_token(t_next)) + if t_generated[-1] == '': break s = ' '.join(t_generated) @@ -296,25 +376,15 @@ class TaskMNIST(Task): data_input = data_set.data.view(-1, 28 * 28).long() if args.data_size >= 0: data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): + for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'): yield batch def vocabulary_size(self): return 256 - def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s + 1] = t - + def produce_results(self, n_epoch, model): + nb_samples = 64 + results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' 
torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) @@ -322,27 +392,21 @@ class TaskMNIST(Task): ###################################################################### -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - log_string(f'device {device}') if args.data == 'wiki103': + nb_epochs_default = 10 task = TaskWiki103(batch_size = args.batch_size, device = device) elif args.data == 'mnist': + nb_epochs_default = 25 task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) + nb_epochs_default = 10 + task = TaskPicoCLVR(batch_size = args.batch_size, + height = args.picoclvr_height, + width = args.picoclvr_width, + nb_colors = args.picoclvr_nb_colors, + device = device) else: raise ValueError(f'Unknown dataset {args.data}.') @@ -358,11 +422,11 @@ model = mygpt.MyGPT( nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout ) +model.to(device) + nb_parameters = sum(p.numel() for p in model.parameters()) log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') -model.to(device) - ###################################################################### if args.optim == 'sgd': @@ -374,7 +438,40 @@ elif args.optim == 'adamw': else: raise ValueError(f'Unknown optimizer {args.optim}.') -for k in range(args.nb_epochs): +###################################################################### + +nb_epochs_finished = 0 + +if args.no_checkpoint: + log_string(f'not trying to load checkpoint.') + +else: + try: + checkpoint = torch.load(args.checkpoint_name, map_location = device) + nb_epochs_finished = checkpoint['nb_epochs_finished'] + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.') + + except FileNotFoundError: + log_string('starting from scratch.') + + except: + log_string('error when loading the checkpoint.') + exit(1) + +###################################################################### + +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +token_count = 0 +for input in task.batches(split = 'train'): + token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +for k in range(nb_epochs_finished, nb_epochs): model.train() @@ -383,7 +480,7 @@ for k in range(args.nb_epochs): for input in task.batches(split = 'train'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -400,15 +497,23 @@ for k in range(args.nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = 
F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') + log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}') task.produce_results(k, model) + checkpoint = { + 'nb_epochs_finished': k + 1, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(checkpoint, args.checkpoint_name) + ######################################################################
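
The central refactor in this diff is the new autoregression() helper, which replaces the per-task token-by-token generation loops previously duplicated in TaskMNIST.produce_results and TaskPicoCLVR.produce_results. Below is a minimal, self-contained sketch of the same batched generation pattern, assuming PyTorch only; the TinyLM model, its sizes, and the autoregression_sketch wrapper are illustrative stand-ins, not part of the repository. As in the commit, the sampled token is written at the same position s whose logits are read, which assumes the model handles the causal shift internally (consistent with the training loss in this diff comparing output directly against input).

import torch
from torch import nn

class TinyLM(nn.Module):
    # Hypothetical toy language model: embeds tokens and emits per-position logits.
    def __init__(self, vocab_size = 16, dim = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        return self.head(self.embed(x))  # shape (N, T, vocab_size)

def autoregression_sketch(model, batch_size, nb_samples, nb_tokens,
                          primer = None, synthesis_sampling = True,
                          device = torch.device('cpu')):
    # Pre-allocate the result tensor; an optional primer is prepended and the
    # generation loop starts after it, as in the commit's autoregression().
    results = torch.zeros(nb_samples, nb_tokens, dtype = torch.int64, device = device)
    first = 0
    if primer is not None:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    # split() returns views into results, so in-place writes fill it batch by batch.
    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            logits = model(input)[:, s]
            if synthesis_sampling:
                # Sample from the categorical distribution over the vocabulary,
                # mirroring the --synthesis_sampling behaviour.
                t_next = torch.distributions.Categorical(logits = logits).sample()
            else:
                # Greedy decoding.
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results

# Usage example with the toy model.
model = TinyLM()
samples = autoregression_sketch(model, batch_size = 4, nb_samples = 8, nb_tokens = 10)
print(samples.size())  # torch.Size([8, 10])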