X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=954f4f088c29116052f3258dd13eda2f5a6b3b0c;hb=f082ce9255f87b6c1bfebfac00f94820a16e04f1;hp=970ee7b528f0f3f91bd777b760b48140c343f678;hpb=68c17359790a9b8ac931a3679f08ad6a82a4e640;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 970ee7b..954f4f0 100755 --- a/mygpt.py +++ b/mygpt.py @@ -5,95 +5,16 @@ # Written by Francois Fleuret -import math, sys, argparse, time, tqdm, itertools +import math + +import torch -import torch, torchtext, torchvision from torch import nn from torch.nn import functional as F -###################################################################### - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -###################################################################### - -parser = argparse.ArgumentParser(description = 'My own GPT.') - -parser.add_argument('--log_filename', - type = str, default = 'train.log') - -parser.add_argument('--download', - type = bool, default = False) - -parser.add_argument('--seed', - type = int, default = 0) - -parser.add_argument('--nb_epochs', - type = int, default = 100) - -parser.add_argument('--batch_size', - type = int, default = 25) - -parser.add_argument('--data', - type = str, default = 'wiki103') - -parser.add_argument('--data_size', - type = int, default = -1) - -parser.add_argument('--optim', - type = str, default = 'adam') - -parser.add_argument('--learning_rate', - type = float, default = 1e-4) - -parser.add_argument('--dim_model', - type = int, default = 512) - -parser.add_argument('--dim_keys', - type = int, default = 64) - -parser.add_argument('--dim_hidden', - type = int, default = 2048) - -parser.add_argument('--nb_heads', - type = int, default = 8) - -parser.add_argument('--nb_blocks', - type = int, default = 12) - -parser.add_argument('--dropout', - type = float, default = 0.1) - -parser.add_argument('--synthesis_sampling', - type = bool, default = True) - -###################################################################### - -args = parser.parse_args() - -log_file = open(args.log_filename, 'w') - -if args.seed >= 0: - torch.manual_seed(args.seed) - -###################################################################### - -def log_string(s): - t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) - - if log_file is not None: - log_file.write(t + s + '\n') - log_file.flush() - - print(t + s) - sys.stdout.flush() - -for n in vars(args): - log_string(f'args.{n} {getattr(args, n)}') - ############################## -class Residual(nn.Module): +class WithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -103,48 +24,59 @@ class Residual(nn.Module): ############################## -class PositionalEncoding(nn.Module): +class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max - # From Vaswani et al 2018 - # PE_{t,2i} = sin(t/(L^{2i/D})) - # PE_{t,2i+1} = cos(t/(L^{2i/D})) + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] k = j%2 - return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :] + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + return x + pe ############################## class QKVAttention(nn.Module): - def 
__init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0): + def __init__(self, + dim_in, dim_qk, dim_v, + nb_heads = 1, causal = False, attention_dropout = 0.0): super().__init__() def randw(*d): - return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1]))) + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - self.wq = randw(nb_heads, dim_qk, dim_in) - self.wk = randw(nb_heads, dim_qk, dim_in) - self.wv = randw(nb_heads, dim_v, dim_in) self.causal = causal self.attention_dropout = attention_dropout - def forward(self, x): - q = torch.einsum('ntc,hdc->nhtd', x, self.wq) - k = torch.einsum('ntc,hdc->nhtd', x, self.wk) - v = torch.einsum('ntc,hdc->nhtd', x, self.wv) - r = math.sqrt(q.size(3)) - a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r) + self.w_q = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def forward(self, x_q, x_kv = None): + if x_kv is None: x_kv = x_q + + q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q) + k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k) + v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) + + a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) + if self.causal: - mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0 + mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ + < torch.arange(a.size(3), device = q.device)[None, None, None, :] a = a.masked_fill(mask, float('-inf')) + a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nhtd', a, v) - return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd) + y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2) + + y = y @ self.w_o + + return y ############################## @@ -152,7 +84,8 @@ class MyGPT(nn.Module): def __init__(self, vocabulary_size, dim_model, dim_keys, dim_hidden, - nb_heads, nb_blocks, dropout = 0.): + nb_heads, nb_blocks, + dropout = 0.0, len_max = 1e5): super().__init__() @@ -161,25 +94,25 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max = 1e5), + AddPositionalEncoding(len_max), ) trunk_blocks = [ ] for _ in range(nb_blocks): trunk_blocks += [ - Residual( - nn.LayerNorm(dim_model), + WithResidual( + nn.LayerNorm((dim_model,)), QKVAttention( dim_in = dim_model, - dim_qk = dim_keys, dim_v = dim_model // nb_heads, + dim_qk = dim_keys, + dim_v = dim_model // nb_heads, nb_heads = nb_heads, causal = True, attention_dropout = dropout ), - nn.Linear(in_features = dim_model, out_features = dim_model), ), - Residual( - nn.LayerNorm(dim_model), + WithResidual( + nn.LayerNorm((dim_model,)), nn.Linear(in_features = dim_model, out_features = dim_hidden), nn.ReLU(), nn.Linear(in_features = dim_hidden, out_features = dim_model), @@ -192,323 +125,28 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): + x = F.pad(x, (1, 0)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) + x = F.pad(x, (0, 0, 0, -1)) return x ###################################################################### -class Task: - def batches(self, split = 'train'): - pass - - def vocabulary_size(self): - pass - - def produce_results(self, n_epoch, model, nb_tokens = 50): - pass - -###################################################################### - -import picoclvr - -class 
TaskPicoCLVR(Task): - - def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): - self.batch_size = batch_size - self.device = device - nb = args.data_size if args.data_size > 0 else 250000 - - descr = picoclvr.generate(nb, height = height, width = width) - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - - tokens = set() - for s in descr: - for t in s: tokens.add(t) - self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) - self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] - - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): - yield batch - else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): - yield batch - - def vocabulary_size(self): - return len(self.token2id) - - def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] - nb_per_primer = 8 - for primer in [ - 'red above green green top blue right of red ', - 'there is red there is yellow there is blue ', - 'red below yellow yellow below green green below blue red right yellow left green right blue left ', - 'green bottom yellow bottom green left of blue yellow right of blue blue top ', - ]: - - for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] - - img = torch.cat(img, 0) - file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) - log_string(f'wrote {file_name}') - -###################################################################### - -class TaskWiki103(Task): - - def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100, - device = torch.device('cpu')): - - self.batch_size = batch_size - self.len_min = len_min - self.len_max = len_max - self.min_freq = min_freq - self.device = device - - self.tokenizer = torchtext.data.get_tokenizer('basic_english') - train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/') - - # Mostly for debug - if args.data_size > 0: - train_iter = itertools.islice(train_iter, args.data_size) - - def yield_tokens(): - for l in tqdm.tqdm(train_iter, desc = 'vocab'): - yield self.tokenizer(l) - - self.vocab = torchtext.vocab.build_vocab_from_iterator( - yield_tokens(), - specials = [ '', '' ], - min_freq = self.min_freq - ) - - self.vocab.set_default_index(self.vocab[ '' ]) - - def tensorize(self, s): - a = max(len(x) for x in s) - return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) - - def yield_batches(self, ds): - s = [ ] - for l in ds: - q = self.tokenizer(l) - if len(q) >= self.len_min and len(q) <= self.len_max: - s += [ q ] - if len(s) == 
self.batch_size: - yield self.tensorize(s) - s = [ ] - - if len(s) > 0: - yield self.tensorize(s) - - def batches(self, split = 'train'): - data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') - - # Mostly for debug - if args.data_size > 0: - data_iter = itertools.islice(data_iter, args.data_size) - - return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) - - def vocabulary_size(self): - return len(self.vocab) - - def produce_results(self, n_epoch, model, nb_tokens = 50): - file_name = f'result_wiki103_{n_epoch:04d}.txt' - - with open(file_name, 'w') as outfile: - for primer in [ - 'the cat is hunting a', - 'paris is the capital', - 'cars are convenient', - 'the difference between men and women is', - 'the object was blue all over and green all over it was', - 'cherries are red and lemons are', - 'cherries are sweet and lemons are', - 'two plus three equals', - 'deep learning is', - ]: - t_primer = self.tokenizer(primer) - t_generated = [ ] - - for j in range(nb_tokens): - - input = self.tensorize([ t_primer + t_generated ]).to(self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) - if t_generated[-1] == '': break - - s = ' '.join(t_generated) - - outfile.write(f'<{primer}> {s}\n') - - log_string(f'wrote {file_name}') - -###################################################################### - -class TaskMNIST(Task): - - def __init__(self, batch_size, device = torch.device('cpu')): - self.device = device - self.batch_size = batch_size - - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - data_set = torchvision.datasets.MNIST( - root = './data', train = (split == 'train'), - download = True - ) - data_input = data_set.data.view(-1, 28 * 28).long() - if args.data_size >= 0: - data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): - yield batch - - def vocabulary_size(self): - return 256 - - def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s + 1] = t - - image_name = f'result_mnist_{n_epoch:04d}.png' - torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., - image_name, nrow = 16, pad_value = 0.8) - log_string(f'wrote {image_name}') - -###################################################################### - -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - -log_string(f'device {device}') - -if args.data == 'wiki103': - task = TaskWiki103(batch_size = args.batch_size, device = device) -elif args.data == 'mnist': - task = TaskMNIST(batch_size = args.batch_size, 
device = device) -elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) -else: - raise ValueError(f'Unknown dataset {args.data}.') - -vocabulary_size = task.vocabulary_size() - -log_string(f'vocabulary_size {vocabulary_size}') - -############################## - -model = MyGPT( - vocabulary_size = vocabulary_size, - dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, - nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout -) - -nb_parameters = sum(p.numel() for p in model.parameters()) -log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') - -model.to(device) - -###################################################################### - -if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) -else: - raise ValueError(f'Unknown optimizer {args.optim}.') - -for k in range(args.nb_epochs): - - model.train() - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in task.batches(split = 'train'): - input = input.to(device) - output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - with torch.autograd.no_grad(): - - model.eval() - - nb_test_samples, acc_test_loss = 0, 0.0 +if __name__ == '__main__': + print('Basic check.') - for input in task.batches(split = 'test'): - input = input.to(device) - output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) + vocabulary_size = 10 + x = torch.randint(vocabulary_size, (25, 100)) - log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}') + model = MyGPT( + vocabulary_size = vocabulary_size, + dim_model = 18, dim_keys = 50, dim_hidden = 100, + nb_heads = 2, nb_blocks = 3, + dropout = 0.1 + ) - task.produce_results(k, model) + y = model(x) ######################################################################
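
Note: after this change, MyGPT.forward() left-pads the token sequence with a 0 and drops the last output position, so the logits at position t are a prediction of x[:, t] from x[:, :t]. A caller can therefore compute the autoregressive cross-entropy directly against the unshifted input, without the output[:, :-1] / input[:, 1:] shift used by the removed training loop. A minimal sketch of such an external training step (the hyper-parameters, the Adam optimizer, and the toy batch x are illustrative assumptions, not part of this commit; mygpt is assumed importable from the repository directory):

    import torch
    from torch.nn import functional as F

    import mygpt

    model = mygpt.MyGPT(
        vocabulary_size = 10,
        dim_model = 18, dim_keys = 50, dim_hidden = 100,
        nb_heads = 2, nb_blocks = 3,
        dropout = 0.1,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

    x = torch.randint(10, (25, 100))   # toy batch of token sequences
    output = model(x)                  # (25, 100, 10): logits predicting x[:, t] from x[:, :t]
    loss = F.cross_entropy(output.transpose(1, 2), x)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()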
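
The removed check_causality() helper referenced names that no longer exist in the module-only file (its `m = model[1:]` line was commented out, and `dim_model` was a script-level value). A self-contained equivalent for the new QKVAttention could look like the sketch below (the dimensions are arbitrary test values; it uses autograd to verify that, with causal = True, the output at position t receives no gradient from inputs at positions s > t):

    import torch

    import mygpt

    attention = mygpt.QKVAttention(
        dim_in = 16, dim_qk = 8, dim_v = 8,
        nb_heads = 2, causal = True,
    )

    x = torch.randn(1, 5, 16).requires_grad_()
    y = attention(x)

    for t in range(y.size(1)):
        g, = torch.autograd.grad(y[0, t].sum(), x, retain_graph = True)
        # positions after t must not influence the output at position t
        assert g[0, t + 1:].abs().sum().item() == 0.0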