Initial commit
[mygpt.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 ######################################################################
15
16 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
18 ######################################################################
19
20 parser = argparse.ArgumentParser(description = 'My own GPT.')
21
22 parser.add_argument('--log_filename',
23                     type = str, default = 'train.log')
24
25 parser.add_argument('--download',
26                     type = bool, default = False)
27
28 parser.add_argument('--seed',
29                     type = int, default = 0)
30
31 parser.add_argument('--nb_epochs',
32                     type = int, default = 100)
33
34 parser.add_argument('--batch_size',
35                     type = int, default = 25)
36
37 parser.add_argument('--data',
38                     type = str, default = 'wiki103')
39
40 parser.add_argument('--data_size',
41                     type = int, default = -1)
42
43 parser.add_argument('--optim',
44                     type = str, default = 'adam')
45
46 parser.add_argument('--learning_rate',
47                     type = float, default = 1e-4)
48
49 parser.add_argument('--dim_model',
50                     type = int, default = 512)
51
52 parser.add_argument('--dim_keys',
53                     type = int, default = 64)
54
55 parser.add_argument('--dim_hidden',
56                     type = int, default = 2048)
57
58 parser.add_argument('--nb_heads',
59                     type = int, default = 8)
60
61 parser.add_argument('--nb_blocks',
62                     type = int, default = 12)
63
64 parser.add_argument('--dropout',
65                     type = float, default = 0.1)
66
67 parser.add_argument('--synthesis_sampling',
68                     type = bool, default = True)
69
70 ######################################################################
71
72 args = parser.parse_args()
73
74 log_file = open(args.log_filename, 'w')
75
76 if args.seed >= 0:
77     torch.manual_seed(args.seed)
78
79 ######################################################################
80
81 def log_string(s):
82     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
83
84     if log_file is not None:
85         log_file.write(t + s + '\n')
86         log_file.flush()
87
88     print(t + s)
89     sys.stdout.flush()
90
91 for n in vars(args):
92     log_string(f'args.{n} {getattr(args, n)}')
93
94 ##############################
95
96 class Residual(nn.Module):
97     def __init__(self, *f):
98         super().__init__()
99         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
100
101     def forward(self, x):
102         return x + self.f(x)
103
104 ##############################
105
106 class PositionalEncoding(nn.Module):
107     def __init__(self, len_max):
108         super().__init__()
109         self.len_max = len_max
110
111     # From Vaswani et al 2018
112     # PE_{t,2i}   = sin(t/(L^{2i/D}))
113     # PE_{t,2i+1} = cos(t/(L^{2i/D}))
114     def forward(self, x):
115         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
116         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
117         k = j%2
118         return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
119
120 ##############################
121
122 class QKVAttention(nn.Module):
123     def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
124         super().__init__()
125
126         def randw(*d):
127             return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
128
129         self.wq = randw(nb_heads, dim_qk, dim_in)
130         self.wk = randw(nb_heads, dim_qk, dim_in)
131         self.wv = randw(nb_heads, dim_v, dim_in)
132         self.causal = causal
133         self.attention_dropout = attention_dropout
134
135     def forward(self, x):
136         q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
137         k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
138         v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
139         r = math.sqrt(q.size(3))
140         a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
141         if self.causal:
142             mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
143             a = a.masked_fill(mask, float('-inf'))
144         a = a.softmax(dim = 3)
145         a = F.dropout(a, self.attention_dropout, self.training)
146         y = torch.einsum('nhts,nhsd->nhtd', a, v)
147         return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
148
149 ##############################
150
151 class MyGPT(nn.Module):
152     def __init__(self,
153                  vocabulary_size,
154                  dim_model, dim_keys, dim_hidden,
155                  nb_heads, nb_blocks, dropout = 0.):
156
157         super().__init__()
158
159         assert dim_model % nb_heads == 0
160
161         self.embedding = nn.Sequential(
162             nn.Embedding(vocabulary_size, dim_model),
163             nn.Dropout(dropout),
164             PositionalEncoding(len_max = 1e5),
165         )
166
167         trunk_blocks = [ ]
168
169         for _ in range(nb_blocks):
170             trunk_blocks += [
171                 Residual(
172                     nn.LayerNorm(dim_model),
173                     QKVAttention(
174                         dim_in = dim_model,
175                         dim_qk = dim_keys, dim_v = dim_model // nb_heads,
176                         nb_heads = nb_heads,
177                         causal = True, attention_dropout = dropout
178                     ),
179                     nn.Linear(in_features = dim_model, out_features = dim_model),
180                 ),
181                 Residual(
182                     nn.LayerNorm(dim_model),
183                     nn.Linear(in_features = dim_model, out_features = dim_hidden),
184                     nn.ReLU(),
185                     nn.Linear(in_features = dim_hidden, out_features = dim_model),
186                     nn.Dropout(dropout),
187                 ),
188             ]
189
190         self.trunk = nn.Sequential(*trunk_blocks)
191
192         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
193
194     def forward(self, x):
195         x = self.embedding(x)
196         x = self.trunk(x)
197         x = self.readout(x)
198         return x
199
200 ######################################################################
201
202 class Task:
203     def batches(self, split = 'train'):
204         pass
205
206     def vocabulary_size(self):
207         pass
208
209     def produce_results(self, n_epoch, model, nb_tokens = 50):
210         pass
211
212 ######################################################################
213
214 import picoclvr
215
216 class TaskPicoCLVR(Task):
217
218     def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
219         self.batch_size = batch_size
220         self.device = device
221         nb = args.data_size if args.data_size > 0 else 250000
222
223         descr = picoclvr.generate(nb, height = height, width = width)
224         descr = [ s.strip().split(' ') for s in descr ]
225         l = max([ len(s) for s in descr ])
226         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
227
228         tokens = set()
229         for s in descr:
230             for t in s: tokens.add(t)
231         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
232         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
233
234         t = [ [ self.token2id[u] for u in s ] for s in descr ]
235         data_input = torch.tensor(t, device = self.device)
236
237         self.test_input = data_input[:nb // 5]
238         self.train_input = data_input[nb // 5:]
239
240     def batches(self, split = 'train'):
241         assert split in { 'train', 'test' }
242         if split == 'train':
243             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
244                 yield batch
245         else:
246             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
247                 yield batch
248
249     def vocabulary_size(self):
250         return len(self.token2id)
251
252     def produce_results(self, n_epoch, model, nb_tokens = 50):
253         img = [ ]
254         nb_per_primer = 8
255         for primer in [
256                 'red above green <sep> green top <sep> blue right of red <img>',
257                 'there is red <sep> there is yellow <sep> there is blue <img>',
258                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
259                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
260         ]:
261
262             for k in range(nb_per_primer):
263                 t_primer = primer.strip().split(' ')
264                 t_generated = [ ]
265
266                 for j in range(nb_tokens):
267                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
268                     input = torch.tensor(t, device = self.device)
269                     output = model(input)
270                     logits = output[0, -1]
271                     if args.synthesis_sampling:
272                         dist = torch.distributions.categorical.Categorical(logits = logits)
273                         t = dist.sample()
274                     else:
275                         t = logits.argmax()
276                     t_generated.append(self.id2token[t.item()])
277
278                 descr = [ ' '.join(t_primer + t_generated) ]
279                 img += [ picoclvr.descr2img(descr) ]
280
281         img = torch.cat(img, 0)
282         file_name = f'result_picoclvr_{n_epoch:04d}.png'
283         torchvision.utils.save_image(img / 255.,
284                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
285         log_string(f'wrote {file_name}')
286
287 ######################################################################
288
289 class TaskWiki103(Task):
290
291     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
292                  device = torch.device('cpu')):
293
294         self.batch_size = batch_size
295         self.len_min = len_min
296         self.len_max = len_max
297         self.min_freq = min_freq
298         self.device = device
299
300         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
301         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
302
303         # Mostly for debug
304         if args.data_size > 0:
305             train_iter = itertools.islice(train_iter, args.data_size)
306
307         def yield_tokens():
308             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
309                 yield self.tokenizer(l)
310
311         self.vocab = torchtext.vocab.build_vocab_from_iterator(
312             yield_tokens(),
313             specials = [ '<unk>', '<non>' ],
314             min_freq = self.min_freq
315         )
316
317         self.vocab.set_default_index(self.vocab[ '<unk>' ])
318
319     def tensorize(self, s):
320         a = max(len(x) for x in s)
321         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
322
323     def yield_batches(self, ds):
324         s = [ ]
325         for l in ds:
326             q = self.tokenizer(l)
327             if len(q) >= self.len_min and len(q) <= self.len_max:
328                 s += [ q ]
329                 if len(s) == self.batch_size:
330                     yield self.tensorize(s)
331                     s = [ ]
332
333         if len(s) > 0:
334             yield self.tensorize(s)
335
336     def batches(self, split = 'train'):
337         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
338
339         # Mostly for debug
340         if args.data_size > 0:
341             data_iter = itertools.islice(data_iter, args.data_size)
342
343         return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
344
345     def vocabulary_size(self):
346         return len(self.vocab)
347
348     def produce_results(self, n_epoch, model, nb_tokens = 50):
349         file_name = f'result_wiki103_{n_epoch:04d}.txt'
350
351         with open(file_name, 'w') as outfile:
352              for primer in [
353                      'the cat is hunting a',
354                      'paris is the capital',
355                      'cars are convenient',
356                      'the difference between men and women is',
357                      'the object was blue all over and green all over it was',
358                      'cherries are red and lemons are',
359                      'cherries are sweet and lemons are',
360                      'two plus three equals',
361                      'deep learning is',
362              ]:
363                  t_primer = self.tokenizer(primer)
364                  t_generated = [ ]
365
366                  for j in range(nb_tokens):
367
368                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
369                      output = model(input)
370                      logits = output[0, -1]
371                      if args.synthesis_sampling:
372                          dist = torch.distributions.categorical.Categorical(logits = logits)
373                          t = dist.sample()
374                      else:
375                          t = logits.argmax()
376                      t_generated.append(self.vocab.lookup_token(t))
377                      if t_generated[-1] == '<non>': break
378
379                  s = ' '.join(t_generated)
380
381                  outfile.write(f'<{primer}> {s}\n')
382
383         log_string(f'wrote {file_name}')
384
385 ######################################################################
386
387 class TaskMNIST(Task):
388
389     def __init__(self, batch_size, device = torch.device('cpu')):
390         self.device = device
391         self.batch_size = batch_size
392
393     def batches(self, split = 'train'):
394         assert split in { 'train', 'test' }
395         data_set = torchvision.datasets.MNIST(
396             root = './data', train = (split == 'train'),
397             download = True
398         )
399         data_input = data_set.data.view(-1, 28 * 28).long()
400         if args.data_size >= 0:
401             data_input = data_input[:args.data_size]
402         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
403             yield batch
404
405     def vocabulary_size(self):
406         return 256
407
408     def produce_results(self, n_epoch, model, nb_samples = 64):
409         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
410         for input in results.split(self.batch_size):
411             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
412                 output = model(input)
413                 logits = output[:, s]
414                 if args.synthesis_sampling:
415                     dist = torch.distributions.categorical.Categorical(logits = logits)
416                     t = dist.sample()
417                 else:
418                     t = logits.argmax(1)
419                 input[:, s + 1] = t
420
421         image_name = f'result_mnist_{n_epoch:04d}.png'
422         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
423                                      image_name, nrow = 16, pad_value = 0.8)
424         log_string(f'wrote {image_name}')
425
426 ######################################################################
427
428 def check_causality(model):
429     #m = model[1:]
430     input = torch.rand(1, 5, dim_model).requires_grad_()
431     output = m(input)
432     a = torch.zeros(output.size(1), input.size(1))
433     for k in range(output.size(1)):
434         for d in range(output.size(2)):
435             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
436             a[k] += g.squeeze(0).pow(2).sum(1)
437     print(a)
438
439 ######################################################################
440
441 log_string(f'device {device}')
442
443 if args.data == 'wiki103':
444     task = TaskWiki103(batch_size = args.batch_size, device = device)
445 elif args.data == 'mnist':
446     task = TaskMNIST(batch_size = args.batch_size, device = device)
447 elif args.data == 'picoclvr':
448     task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
449 else:
450     raise ValueError(f'Unknown dataset {args.data}.')
451
452 vocabulary_size = task.vocabulary_size()
453
454 log_string(f'vocabulary_size {vocabulary_size}')
455
456 ##############################
457
458 model = MyGPT(
459     vocabulary_size = vocabulary_size,
460     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
461     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
462 )
463
464 nb_parameters = sum(p.numel() for p in model.parameters())
465 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
466
467 model.to(device)
468
469 ######################################################################
470
471 if args.optim == 'sgd':
472     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
473 elif args.optim == 'adam':
474     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
475 elif args.optim == 'adamw':
476     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
477 else:
478     raise ValueError(f'Unknown optimizer {args.optim}.')
479
480 for k in range(args.nb_epochs):
481
482     model.train()
483
484     nb_train_samples, acc_train_loss = 0, 0.0
485
486     for input in task.batches(split = 'train'):
487         input = input.to(device)
488         output = model(input)
489         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
490         acc_train_loss += loss.item() * input.size(0)
491         nb_train_samples += input.size(0)
492
493         optimizer.zero_grad()
494         loss.backward()
495         optimizer.step()
496
497     with torch.autograd.no_grad():
498
499         model.eval()
500
501         nb_test_samples, acc_test_loss = 0, 0.0
502
503         for input in task.batches(split = 'test'):
504             input = input.to(device)
505             output = model(input)
506             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
507             acc_test_loss += loss.item() * input.size(0)
508             nb_test_samples += input.size(0)
509
510         log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}')
511
512         task.produce_results(k, model)
513
514 ######################################################################