Update.
[mygpt.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 ######################################################################
15
16 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
18 ######################################################################
19
20 parser = argparse.ArgumentParser(description = 'My own GPT.')
21
22 parser.add_argument('--log_filename',
23                     type = str, default = 'train.log')
24
25 parser.add_argument('--download',
26                     type = bool, default = False)
27
28 parser.add_argument('--seed',
29                     type = int, default = 0)
30
31 parser.add_argument('--nb_epochs',
32                     type = int, default = 100)
33
34 parser.add_argument('--batch_size',
35                     type = int, default = 25)
36
37 parser.add_argument('--data',
38                     type = str, default = 'wiki103')
39
40 parser.add_argument('--data_size',
41                     type = int, default = -1)
42
43 parser.add_argument('--optim',
44                     type = str, default = 'adam')
45
46 parser.add_argument('--learning_rate',
47                     type = float, default = 1e-4)
48
49 parser.add_argument('--dim_model',
50                     type = int, default = 512)
51
52 parser.add_argument('--dim_keys',
53                     type = int, default = 64)
54
55 parser.add_argument('--dim_hidden',
56                     type = int, default = 2048)
57
58 parser.add_argument('--nb_heads',
59                     type = int, default = 8)
60
61 parser.add_argument('--nb_blocks',
62                     type = int, default = 12)
63
64 parser.add_argument('--dropout',
65                     type = float, default = 0.1)
66
67 parser.add_argument('--synthesis_sampling',
68                     type = bool, default = True)
69
70 ######################################################################
71
72 args = parser.parse_args()
73
74 log_file = open(args.log_filename, 'w')
75
76 if args.seed >= 0:
77     torch.manual_seed(args.seed)
78
79 ######################################################################
80
81 def log_string(s):
82     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
83
84     if log_file is not None:
85         log_file.write(t + s + '\n')
86         log_file.flush()
87
88     print(t + s)
89     sys.stdout.flush()
90
91 for n in vars(args):
92     log_string(f'args.{n} {getattr(args, n)}')
93
94 ##############################
95
96 class Residual(nn.Module):
97     def __init__(self, *f):
98         super().__init__()
99         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
100
101     def forward(self, x):
102         return x + self.f(x)
103
104 ##############################
105
106 class PositionalEncoding(nn.Module):
107     def __init__(self, len_max):
108         super().__init__()
109         self.len_max = len_max
110
111     # From Vaswani et al 2018
112     # PE_{t,2i}   = sin(t/(L^{2i/D}))
113     # PE_{t,2i+1} = cos(t/(L^{2i/D}))
114     def forward(self, x):
115         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
116         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
117         k = j%2
118         return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
119
120 ##############################
121
122 class QKVAttention(nn.Module):
123     def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
124         super().__init__()
125
126         def randw(*d):
127             return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
128
129         self.wq = randw(nb_heads, dim_qk, dim_in)
130         self.wk = randw(nb_heads, dim_qk, dim_in)
131         self.wv = randw(nb_heads, dim_v, dim_in)
132         self.causal = causal
133         self.attention_dropout = attention_dropout
134
135     def forward(self, x):
136         q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
137         k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
138         v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
139         r = math.sqrt(q.size(3))
140         a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
141         if self.causal:
142             mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
143             a = a.masked_fill(mask, float('-inf'))
144         a = a.softmax(dim = 3)
145         a = F.dropout(a, self.attention_dropout, self.training)
146         y = torch.einsum('nhts,nhsd->nhtd', a, v)
147         return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
148
149 ##############################
150
151 class MyGPT(nn.Module):
152     def __init__(self,
153                  vocabulary_size,
154                  dim_model, dim_keys, dim_hidden,
155                  nb_heads, nb_blocks, dropout = 0.):
156
157         super().__init__()
158
159         assert dim_model % nb_heads == 0
160
161         self.embedding = nn.Sequential(
162             nn.Embedding(vocabulary_size, dim_model),
163             nn.Dropout(dropout),
164             PositionalEncoding(len_max = 1e5),
165         )
166
167         trunk_blocks = [ ]
168
169         for _ in range(nb_blocks):
170             trunk_blocks += [
171                 Residual(
172                     nn.LayerNorm(dim_model),
173                     QKVAttention(
174                         dim_in = dim_model,
175                         dim_qk = dim_keys, dim_v = dim_model // nb_heads,
176                         nb_heads = nb_heads,
177                         causal = True, attention_dropout = dropout
178                     ),
179                     nn.Linear(in_features = dim_model, out_features = dim_model),
180                 ),
181                 Residual(
182                     nn.LayerNorm(dim_model),
183                     nn.Linear(in_features = dim_model, out_features = dim_hidden),
184                     nn.ReLU(),
185                     nn.Linear(in_features = dim_hidden, out_features = dim_model),
186                     nn.Dropout(dropout),
187                 ),
188             ]
189
190         self.trunk = nn.Sequential(*trunk_blocks)
191
192         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
193
194     def forward(self, x):
195         x = self.embedding(x)
196         x = self.trunk(x)
197         x = self.readout(x)
198         return x
199
200 ######################################################################
201
202 class Task:
203     def batches(self, split = 'train'):
204         pass
205
206     def vocabulary_size(self):
207         pass
208
209     def produce_results(self, n_epoch, model, nb_tokens = 50):
210         pass
211
212 ######################################################################
213
214 import picoclvr
215
216 class TaskPicoCLVR(Task):
217
218     def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
219         self.batch_size = batch_size
220         self.device = device
221         nb = args.data_size if args.data_size > 0 else 250000
222
223         descr = picoclvr.generate(nb, height = height, width = width)
224         descr = [ s.strip().split(' ') for s in descr ]
225         l = max([ len(s) for s in descr ])
226         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
227
228         tokens = set()
229         for s in descr:
230             for t in s: tokens.add(t)
231         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
232         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
233
234         t = [ [ self.token2id[u] for u in s ] for s in descr ]
235         data_input = torch.tensor(t, device = self.device)
236
237         self.test_input = data_input[:nb // 5]
238         self.train_input = data_input[nb // 5:]
239
240     def batches(self, split = 'train'):
241         assert split in { 'train', 'test' }
242         if split == 'train':
243             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
244                 yield batch
245         else:
246             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
247                 yield batch
248
249     def vocabulary_size(self):
250         return len(self.token2id)
251
252     def produce_results(self, n_epoch, model, nb_tokens = 50):
253         img = [ ]
254         nb_per_primer = 8
255
256         for primer in [
257                 'red above green <sep> green top <sep> blue right of red <img>',
258                 'there is red <sep> there is yellow <sep> there is blue <img>',
259                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
260                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
261         ]:
262
263             for k in range(nb_per_primer):
264                 t_primer = primer.strip().split(' ')
265                 t_generated = [ ]
266
267                 for j in range(nb_tokens):
268                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
269                     input = torch.tensor(t, device = self.device)
270                     output = model(input)
271                     logits = output[0, -1]
272                     if args.synthesis_sampling:
273                         dist = torch.distributions.categorical.Categorical(logits = logits)
274                         t = dist.sample()
275                     else:
276                         t = logits.argmax()
277                     t_generated.append(self.id2token[t.item()])
278
279                 descr = [ ' '.join(t_primer + t_generated) ]
280                 img += [ picoclvr.descr2img(descr) ]
281
282         img = torch.cat(img, 0)
283         file_name = f'result_picoclvr_{n_epoch:04d}.png'
284         torchvision.utils.save_image(img / 255.,
285                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
286         log_string(f'wrote {file_name}')
287
288 ######################################################################
289
290 class TaskWiki103(Task):
291
292     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
293                  device = torch.device('cpu')):
294
295         self.batch_size = batch_size
296         self.len_min = len_min
297         self.len_max = len_max
298         self.min_freq = min_freq
299         self.device = device
300
301         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
302         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
303
304         # Mostly for debug
305         if args.data_size > 0:
306             train_iter = itertools.islice(train_iter, args.data_size)
307
308         def yield_tokens():
309             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
310                 yield self.tokenizer(l)
311
312         self.vocab = torchtext.vocab.build_vocab_from_iterator(
313             yield_tokens(),
314             specials = [ '<unk>', '<non>' ],
315             min_freq = self.min_freq
316         )
317
318         self.vocab.set_default_index(self.vocab[ '<unk>' ])
319
320     def tensorize(self, s):
321         a = max(len(x) for x in s)
322         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
323
324     def yield_batches(self, ds):
325         s = [ ]
326         for l in ds:
327             q = self.tokenizer(l)
328             if len(q) >= self.len_min and len(q) <= self.len_max:
329                 s += [ q ]
330                 if len(s) == self.batch_size:
331                     yield self.tensorize(s)
332                     s = [ ]
333
334         if len(s) > 0:
335             yield self.tensorize(s)
336
337     def batches(self, split = 'train'):
338         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
339
340         # Mostly for debug
341         if args.data_size > 0:
342             data_iter = itertools.islice(data_iter, args.data_size)
343
344         return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
345
346     def vocabulary_size(self):
347         return len(self.vocab)
348
349     def produce_results(self, n_epoch, model, nb_tokens = 50):
350         file_name = f'result_wiki103_{n_epoch:04d}.txt'
351
352         with open(file_name, 'w') as outfile:
353              for primer in [
354                      'the cat is hunting a',
355                      'paris is the capital',
356                      'cars are convenient',
357                      'the difference between men and women is',
358                      'the object was blue all over and green all over it was',
359                      'cherries are red and lemons are',
360                      'cherries are sweet and lemons are',
361                      'two plus three equals',
362                      'deep learning is',
363              ]:
364                  t_primer = self.tokenizer(primer)
365                  t_generated = [ ]
366
367                  for j in range(nb_tokens):
368
369                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
370                      output = model(input)
371                      logits = output[0, -1]
372                      if args.synthesis_sampling:
373                          dist = torch.distributions.categorical.Categorical(logits = logits)
374                          t = dist.sample()
375                      else:
376                          t = logits.argmax()
377                      t_generated.append(self.vocab.lookup_token(t))
378                      if t_generated[-1] == '<non>': break
379
380                  s = ' '.join(t_generated)
381
382                  outfile.write(f'<{primer}> {s}\n')
383
384         log_string(f'wrote {file_name}')
385
386 ######################################################################
387
388 class TaskMNIST(Task):
389
390     def __init__(self, batch_size, device = torch.device('cpu')):
391         self.device = device
392         self.batch_size = batch_size
393
394     def batches(self, split = 'train'):
395         assert split in { 'train', 'test' }
396         data_set = torchvision.datasets.MNIST(
397             root = './data', train = (split == 'train'),
398             download = True
399         )
400         data_input = data_set.data.view(-1, 28 * 28).long()
401         if args.data_size >= 0:
402             data_input = data_input[:args.data_size]
403         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
404             yield batch
405
406     def vocabulary_size(self):
407         return 256
408
409     def produce_results(self, n_epoch, model, nb_samples = 64):
410         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
411         for input in results.split(self.batch_size):
412             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
413                 output = model(input)
414                 logits = output[:, s]
415                 if args.synthesis_sampling:
416                     dist = torch.distributions.categorical.Categorical(logits = logits)
417                     t = dist.sample()
418                 else:
419                     t = logits.argmax(1)
420                 input[:, s + 1] = t
421
422         image_name = f'result_mnist_{n_epoch:04d}.png'
423         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
424                                      image_name, nrow = 16, pad_value = 0.8)
425         log_string(f'wrote {image_name}')
426
427 ######################################################################
428
429 def check_causality(model):
430     #m = model[1:]
431     input = torch.rand(1, 5, dim_model).requires_grad_()
432     output = m(input)
433     a = torch.zeros(output.size(1), input.size(1))
434     for k in range(output.size(1)):
435         for d in range(output.size(2)):
436             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
437             a[k] += g.squeeze(0).pow(2).sum(1)
438     print(a)
439
440 ######################################################################
441
442 log_string(f'device {device}')
443
444 if args.data == 'wiki103':
445     task = TaskWiki103(batch_size = args.batch_size, device = device)
446 elif args.data == 'mnist':
447     task = TaskMNIST(batch_size = args.batch_size, device = device)
448 elif args.data == 'picoclvr':
449     task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
450 else:
451     raise ValueError(f'Unknown dataset {args.data}.')
452
453 vocabulary_size = task.vocabulary_size()
454
455 log_string(f'vocabulary_size {vocabulary_size}')
456
457 ##############################
458
459 model = MyGPT(
460     vocabulary_size = vocabulary_size,
461     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
462     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
463 )
464
465 nb_parameters = sum(p.numel() for p in model.parameters())
466 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
467
468 model.to(device)
469
470 ######################################################################
471
472 if args.optim == 'sgd':
473     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
474 elif args.optim == 'adam':
475     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
476 elif args.optim == 'adamw':
477     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
478 else:
479     raise ValueError(f'Unknown optimizer {args.optim}.')
480
481 for k in range(args.nb_epochs):
482
483     model.train()
484
485     nb_train_samples, acc_train_loss = 0, 0.0
486
487     for input in task.batches(split = 'train'):
488         input = input.to(device)
489         output = model(input)
490         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
491         acc_train_loss += loss.item() * input.size(0)
492         nb_train_samples += input.size(0)
493
494         optimizer.zero_grad()
495         loss.backward()
496         optimizer.step()
497
498     with torch.autograd.no_grad():
499
500         model.eval()
501
502         nb_test_samples, acc_test_loss = 0, 0.0
503
504         for input in task.batches(split = 'test'):
505             input = input.to(device)
506             output = model(input)
507             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
508             acc_test_loss += loss.item() * input.size(0)
509             nb_test_samples += input.size(0)
510
511         log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}')
512
513         task.produce_results(k, model)
514
515 ######################################################################