Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_many_colors',
79                     action='store_true', default = False)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 class Task:
114     def batches(self, split = 'train'):
115         pass
116
117     def vocabulary_size(self):
118         pass
119
120     def produce_results(self, n_epoch, model, nb_tokens = 50):
121         pass
122
123 ######################################################################
124
125 import picoclvr
126
127 class TaskPicoCLVR(Task):
128
129     def __init__(self, batch_size,
130                  height, width, many_colors = False,
131                  device = torch.device('cpu')):
132
133         def generate_descr(nb):
134             descr = picoclvr.generate(
135                 nb,
136                 height = self.height, width = self.width,
137                 many_colors = many_colors
138             )
139
140             descr = [ s.strip().split(' ') for s in descr ]
141             l = max([ len(s) for s in descr ])
142             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
143
144             return descr
145
146         self.height = height
147         self.width = width
148         self.batch_size = batch_size
149         self.device = device
150         nb = args.data_size if args.data_size > 0 else 250000
151
152         self.train_descr = generate_descr((nb * 4) // 5)
153         self.test_descr = generate_descr((nb * 1) // 5)
154
155         tokens = set()
156         for d in [ self.train_descr, self.test_descr ]:
157             for s in d:
158                 for t in s: tokens.add(t)
159         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
160         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
161
162         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
163         self.train_input = torch.tensor(t, device = self.device)
164         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
165         self.test_input = torch.tensor(t, device = self.device)
166
167     def batches(self, split = 'train'):
168         assert split in { 'train', 'test' }
169         if split == 'train':
170             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
171                 yield batch
172         else:
173             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
174                 yield batch
175
176     def vocabulary_size(self):
177         return len(self.token2id)
178
179     def generate(self, primer, model, nb_tokens):
180         t_primer = primer.strip().split(' ')
181         t_generated = [ ]
182
183         for j in range(nb_tokens):
184             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
185             input = torch.tensor(t, device = self.device)
186             output = model(input)
187             logits = output[0, -1]
188             if args.synthesis_sampling:
189                 dist = torch.distributions.categorical.Categorical(logits = logits)
190                 t = dist.sample()
191             else:
192                 t = logits.argmax()
193             t_generated.append(self.id2token[t.item()])
194
195         return ' '.join(t_primer + t_generated)
196
197     def produce_results(self, n_epoch, model, nb_tokens = None):
198         if nb_tokens is None:
199             nb_tokens = self.height * self.width + 3
200         descr = [ ]
201         nb_per_primer = 8
202
203         for primer in [
204                 'red above green <sep> green top <sep> blue right of red <img>',
205                 'there is red <sep> there is yellow <sep> there is blue <img>',
206                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
207                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
208         ]:
209
210             for k in range(nb_per_primer):
211                 descr.append(self.generate(primer, model, nb_tokens))
212
213         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
214         img = torch.cat(img, 0)
215         file_name = f'result_picoclvr_{n_epoch:04d}.png'
216         torchvision.utils.save_image(
217             img / 255.,
218             file_name, nrow = nb_per_primer, pad_value = 0.8
219         )
220         log_string(f'wrote {file_name}')
221
222         nb_missing = sum( [
223             x[2] for x in picoclvr.nb_missing_properties(
224                 descr,
225                 height = self.height, width = self.width
226             )
227         ] )
228
229         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
230
231 ######################################################################
232
233 class TaskWiki103(Task):
234
235     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
236                  device = torch.device('cpu')):
237
238         self.batch_size = batch_size
239         self.len_min = len_min
240         self.len_max = len_max
241         self.min_freq = min_freq
242         self.device = device
243
244         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
245         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
246
247         # Mostly for debug
248         if args.data_size > 0:
249             train_iter = itertools.islice(train_iter, args.data_size)
250
251         def yield_tokens():
252             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
253                 yield self.tokenizer(l)
254
255         self.vocab = torchtext.vocab.build_vocab_from_iterator(
256             yield_tokens(),
257             specials = [ '<unk>', '<non>' ],
258             min_freq = self.min_freq
259         )
260
261         self.vocab.set_default_index(self.vocab[ '<unk>' ])
262
263     def tensorize(self, s):
264         a = max(len(x) for x in s)
265         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
266
267     def yield_batches(self, ds):
268         s = [ ]
269         for l in ds:
270             q = self.tokenizer(l)
271             if len(q) >= self.len_min and len(q) <= self.len_max:
272                 s += [ q ]
273                 if len(s) == self.batch_size:
274                     yield self.tensorize(s)
275                     s = [ ]
276
277         if len(s) > 0:
278             yield self.tensorize(s)
279
280     def batches(self, split = 'train'):
281         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
282
283         # Mostly for debug
284         if args.data_size > 0:
285             data_iter = itertools.islice(data_iter, args.data_size)
286
287         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
288
289     def vocabulary_size(self):
290         return len(self.vocab)
291
292     def produce_results(self, n_epoch, model, nb_tokens = 50):
293         file_name = f'result_wiki103_{n_epoch:04d}.txt'
294
295         with open(file_name, 'w') as outfile:
296              for primer in [
297                      'the cat is hunting a',
298                      'paris is the capital',
299                      'cars are convenient',
300                      'the difference between men and women is',
301                      'the object was blue all over and green all over it was',
302                      'cherries are red and lemons are',
303                      'cherries are sweet and lemons are',
304                      'two plus three equals',
305                      'deep learning is',
306              ]:
307                  t_primer = self.tokenizer(primer)
308                  t_generated = [ ]
309
310                  for j in range(nb_tokens):
311
312                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
313                      output = model(input)
314                      logits = output[0, -1]
315                      if args.synthesis_sampling:
316                          dist = torch.distributions.categorical.Categorical(logits = logits)
317                          t = dist.sample()
318                      else:
319                          t = logits.argmax()
320                      t_generated.append(self.vocab.lookup_token(t))
321                      if t_generated[-1] == '<non>': break
322
323                  s = ' '.join(t_generated)
324
325                  outfile.write(f'<{primer}> {s}\n')
326
327         log_string(f'wrote {file_name}')
328
329 ######################################################################
330
331 class TaskMNIST(Task):
332
333     def __init__(self, batch_size, device = torch.device('cpu')):
334         self.device = device
335         self.batch_size = batch_size
336
337     def batches(self, split = 'train'):
338         assert split in { 'train', 'test' }
339         data_set = torchvision.datasets.MNIST(
340             root = './data', train = (split == 'train'),
341             download = True
342         )
343         data_input = data_set.data.view(-1, 28 * 28).long()
344         if args.data_size >= 0:
345             data_input = data_input[:args.data_size]
346         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
347             yield batch
348
349     def vocabulary_size(self):
350         return 256
351
352     def produce_results(self, n_epoch, model, nb_samples = 64):
353         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
354         for input in results.split(self.batch_size):
355             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
356                 output = model(input)
357                 logits = output[:, s]
358                 if args.synthesis_sampling:
359                     dist = torch.distributions.categorical.Categorical(logits = logits)
360                     t = dist.sample()
361                 else:
362                     t = logits.argmax(1)
363                 input[:, s + 1] = t
364
365         image_name = f'result_mnist_{n_epoch:04d}.png'
366         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
367                                      image_name, nrow = 16, pad_value = 0.8)
368         log_string(f'wrote {image_name}')
369
370 ######################################################################
371
372 def check_causality(model):
373     #m = model[1:]
374     input = torch.rand(1, 5, dim_model).requires_grad_()
375     output = m(input)
376     a = torch.zeros(output.size(1), input.size(1))
377     for k in range(output.size(1)):
378         for d in range(output.size(2)):
379             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
380             a[k] += g.squeeze(0).pow(2).sum(1)
381     print(a)
382
383 ######################################################################
384
385 log_string(f'device {device}')
386
387 if args.data == 'wiki103':
388     task = TaskWiki103(batch_size = args.batch_size, device = device)
389 elif args.data == 'mnist':
390     task = TaskMNIST(batch_size = args.batch_size, device = device)
391 elif args.data == 'picoclvr':
392     task = TaskPicoCLVR(batch_size = args.batch_size,
393                         height = args.picoclvr_height,
394                         width = args.picoclvr_width,
395                         many_colors = args.picoclvr_many_colors,
396                         device = device)
397 else:
398     raise ValueError(f'Unknown dataset {args.data}.')
399
400 vocabulary_size = task.vocabulary_size()
401
402 log_string(f'vocabulary_size {vocabulary_size}')
403
404 ##############################
405
406 model = mygpt.MyGPT(
407     vocabulary_size = vocabulary_size,
408     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
409     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
410 )
411
412 model.to(device)
413
414 nb_parameters = sum(p.numel() for p in model.parameters())
415 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
416
417 ######################################################################
418
419 if args.optim == 'sgd':
420     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
421 elif args.optim == 'adam':
422     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
423 elif args.optim == 'adamw':
424     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
425 else:
426     raise ValueError(f'Unknown optimizer {args.optim}.')
427
428 ######################################################################
429
430 nb_epochs_finished = 0
431
432 try:
433     checkpoint = torch.load(args.checkpoint_name, map_location = device)
434     nb_epochs_finished = checkpoint['nb_epochs_finished']
435     model.load_state_dict(checkpoint['model_state'])
436     optimizer.load_state_dict(checkpoint['optimizer_state'])
437     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
438
439 except FileNotFoundError:
440     print('Starting from scratch.')
441
442 except:
443     print('Error when loading the checkpoint.')
444     exit(1)
445
446 ######################################################################
447
448 for k in range(nb_epochs_finished, args.nb_epochs):
449
450     model.train()
451
452     nb_train_samples, acc_train_loss = 0, 0.0
453
454     for input in task.batches(split = 'train'):
455         input = input.to(device)
456         output = model(input)
457         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
458         acc_train_loss += loss.item() * input.size(0)
459         nb_train_samples += input.size(0)
460
461         optimizer.zero_grad()
462         loss.backward()
463         optimizer.step()
464
465     with torch.autograd.no_grad():
466
467         model.eval()
468
469         nb_test_samples, acc_test_loss = 0, 0.0
470
471         for input in task.batches(split = 'test'):
472             input = input.to(device)
473             output = model(input)
474             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
475             acc_test_loss += loss.item() * input.size(0)
476             nb_test_samples += input.size(0)
477
478         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
479         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
480
481         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
482
483         task.produce_results(k, model)
484
485     checkpoint = {
486         'nb_epochs_finished': k + 1,
487         'model_state': model.state_dict(),
488         'optimizer_state': optimizer.state_dict()
489     }
490
491     torch.save(checkpoint, args.checkpoint_name)
492
493 ######################################################################