ace376da96dcdffb2fbceec73044052bbdc99aa9
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_many_colors',
79                     action='store_true', default = False)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 class Task:
114     def batches(self, split = 'train'):
115         pass
116
117     def vocabulary_size(self):
118         pass
119
120     def produce_results(self, n_epoch, model, nb_tokens = 50):
121         pass
122
123 ######################################################################
124
125 import picoclvr
126
127 class TaskPicoCLVR(Task):
128
129     def __init__(self, batch_size,
130                  height, width, many_colors = False,
131                  device = torch.device('cpu')):
132
133         self.height = height
134         self.width = width
135         self.batch_size = batch_size
136         self.device = device
137         nb = args.data_size if args.data_size > 0 else 250000
138
139         descr = picoclvr.generate(
140             nb,
141             height = self.height, width = self.width,
142             many_colors = many_colors
143         )
144
145         # self.test_descr = descr[:nb // 5]
146         # self.train_descr = descr[nb // 5:]
147
148         descr = [ s.strip().split(' ') for s in descr ]
149         l = max([ len(s) for s in descr ])
150         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
151
152         tokens = set()
153         for s in descr:
154             for t in s: tokens.add(t)
155         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
156         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
157
158         t = [ [ self.token2id[u] for u in s ] for s in descr ]
159         data_input = torch.tensor(t, device = self.device)
160
161         self.test_input = data_input[:nb // 5]
162         self.train_input = data_input[nb // 5:]
163
164     def batches(self, split = 'train'):
165         assert split in { 'train', 'test' }
166         if split == 'train':
167             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
168                 yield batch
169         else:
170             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
171                 yield batch
172
173     def vocabulary_size(self):
174         return len(self.token2id)
175
176     def generate(self, primer, model, nb_tokens):
177         t_primer = primer.strip().split(' ')
178         t_generated = [ ]
179
180         for j in range(nb_tokens):
181             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
182             input = torch.tensor(t, device = self.device)
183             output = model(input)
184             logits = output[0, -1]
185             if args.synthesis_sampling:
186                 dist = torch.distributions.categorical.Categorical(logits = logits)
187                 t = dist.sample()
188             else:
189                 t = logits.argmax()
190             t_generated.append(self.id2token[t.item()])
191
192         return ' '.join(t_primer + t_generated)
193
194     def produce_results(self, n_epoch, model, nb_tokens = None):
195         if nb_tokens is None:
196             nb_tokens = self.height * self.width + 3
197         descr = [ ]
198         nb_per_primer = 8
199
200         for primer in [
201                 'red above green <sep> green top <sep> blue right of red <img>',
202                 'there is red <sep> there is yellow <sep> there is blue <img>',
203                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
204                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
205         ]:
206
207             for k in range(nb_per_primer):
208                 descr.append(self.generate(primer, model, nb_tokens))
209
210         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
211         img = torch.cat(img, 0)
212         file_name = f'result_picoclvr_{n_epoch:04d}.png'
213         torchvision.utils.save_image(img / 255.,
214                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
215         log_string(f'wrote {file_name}')
216
217         nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr, height = self.height, width = self.width) ] )
218         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
219
220 ######################################################################
221
222 class TaskWiki103(Task):
223
224     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
225                  device = torch.device('cpu')):
226
227         self.batch_size = batch_size
228         self.len_min = len_min
229         self.len_max = len_max
230         self.min_freq = min_freq
231         self.device = device
232
233         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
234         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
235
236         # Mostly for debug
237         if args.data_size > 0:
238             train_iter = itertools.islice(train_iter, args.data_size)
239
240         def yield_tokens():
241             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
242                 yield self.tokenizer(l)
243
244         self.vocab = torchtext.vocab.build_vocab_from_iterator(
245             yield_tokens(),
246             specials = [ '<unk>', '<non>' ],
247             min_freq = self.min_freq
248         )
249
250         self.vocab.set_default_index(self.vocab[ '<unk>' ])
251
252     def tensorize(self, s):
253         a = max(len(x) for x in s)
254         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
255
256     def yield_batches(self, ds):
257         s = [ ]
258         for l in ds:
259             q = self.tokenizer(l)
260             if len(q) >= self.len_min and len(q) <= self.len_max:
261                 s += [ q ]
262                 if len(s) == self.batch_size:
263                     yield self.tensorize(s)
264                     s = [ ]
265
266         if len(s) > 0:
267             yield self.tensorize(s)
268
269     def batches(self, split = 'train'):
270         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
271
272         # Mostly for debug
273         if args.data_size > 0:
274             data_iter = itertools.islice(data_iter, args.data_size)
275
276         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
277
278     def vocabulary_size(self):
279         return len(self.vocab)
280
281     def produce_results(self, n_epoch, model, nb_tokens = 50):
282         file_name = f'result_wiki103_{n_epoch:04d}.txt'
283
284         with open(file_name, 'w') as outfile:
285              for primer in [
286                      'the cat is hunting a',
287                      'paris is the capital',
288                      'cars are convenient',
289                      'the difference between men and women is',
290                      'the object was blue all over and green all over it was',
291                      'cherries are red and lemons are',
292                      'cherries are sweet and lemons are',
293                      'two plus three equals',
294                      'deep learning is',
295              ]:
296                  t_primer = self.tokenizer(primer)
297                  t_generated = [ ]
298
299                  for j in range(nb_tokens):
300
301                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
302                      output = model(input)
303                      logits = output[0, -1]
304                      if args.synthesis_sampling:
305                          dist = torch.distributions.categorical.Categorical(logits = logits)
306                          t = dist.sample()
307                      else:
308                          t = logits.argmax()
309                      t_generated.append(self.vocab.lookup_token(t))
310                      if t_generated[-1] == '<non>': break
311
312                  s = ' '.join(t_generated)
313
314                  outfile.write(f'<{primer}> {s}\n')
315
316         log_string(f'wrote {file_name}')
317
318 ######################################################################
319
320 class TaskMNIST(Task):
321
322     def __init__(self, batch_size, device = torch.device('cpu')):
323         self.device = device
324         self.batch_size = batch_size
325
326     def batches(self, split = 'train'):
327         assert split in { 'train', 'test' }
328         data_set = torchvision.datasets.MNIST(
329             root = './data', train = (split == 'train'),
330             download = True
331         )
332         data_input = data_set.data.view(-1, 28 * 28).long()
333         if args.data_size >= 0:
334             data_input = data_input[:args.data_size]
335         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
336             yield batch
337
338     def vocabulary_size(self):
339         return 256
340
341     def produce_results(self, n_epoch, model, nb_samples = 64):
342         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
343         for input in results.split(self.batch_size):
344             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
345                 output = model(input)
346                 logits = output[:, s]
347                 if args.synthesis_sampling:
348                     dist = torch.distributions.categorical.Categorical(logits = logits)
349                     t = dist.sample()
350                 else:
351                     t = logits.argmax(1)
352                 input[:, s + 1] = t
353
354         image_name = f'result_mnist_{n_epoch:04d}.png'
355         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
356                                      image_name, nrow = 16, pad_value = 0.8)
357         log_string(f'wrote {image_name}')
358
359 ######################################################################
360
361 def check_causality(model):
362     #m = model[1:]
363     input = torch.rand(1, 5, dim_model).requires_grad_()
364     output = m(input)
365     a = torch.zeros(output.size(1), input.size(1))
366     for k in range(output.size(1)):
367         for d in range(output.size(2)):
368             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
369             a[k] += g.squeeze(0).pow(2).sum(1)
370     print(a)
371
372 ######################################################################
373
374 log_string(f'device {device}')
375
376 if args.data == 'wiki103':
377     task = TaskWiki103(batch_size = args.batch_size, device = device)
378 elif args.data == 'mnist':
379     task = TaskMNIST(batch_size = args.batch_size, device = device)
380 elif args.data == 'picoclvr':
381     task = TaskPicoCLVR(batch_size = args.batch_size,
382                         height = args.picoclvr_height,
383                         width = args.picoclvr_width,
384                         many_colors = args.picoclvr_many_colors,
385                         device = device)
386 else:
387     raise ValueError(f'Unknown dataset {args.data}.')
388
389 vocabulary_size = task.vocabulary_size()
390
391 log_string(f'vocabulary_size {vocabulary_size}')
392
393 ##############################
394
395 model = mygpt.MyGPT(
396     vocabulary_size = vocabulary_size,
397     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
398     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
399 )
400
401 model.to(device)
402
403 nb_parameters = sum(p.numel() for p in model.parameters())
404 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
405
406 ######################################################################
407
408 if args.optim == 'sgd':
409     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
410 elif args.optim == 'adam':
411     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
412 elif args.optim == 'adamw':
413     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
414 else:
415     raise ValueError(f'Unknown optimizer {args.optim}.')
416
417 ######################################################################
418
419 nb_epochs_finished = 0
420
421 try:
422     checkpoint = torch.load(args.checkpoint_name, map_location = device)
423     nb_epochs_finished = checkpoint['nb_epochs_finished']
424     model.load_state_dict(checkpoint['model_state'])
425     optimizer.load_state_dict(checkpoint['optimizer_state'])
426     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
427
428 except FileNotFoundError:
429     print('Starting from scratch.')
430
431 except:
432     print('Error when loading the checkpoint.')
433     exit(1)
434
435 ######################################################################
436
437 for k in range(nb_epochs_finished, args.nb_epochs):
438
439     model.train()
440
441     nb_train_samples, acc_train_loss = 0, 0.0
442
443     for input in task.batches(split = 'train'):
444         input = input.to(device)
445         output = model(input)
446         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
447         acc_train_loss += loss.item() * input.size(0)
448         nb_train_samples += input.size(0)
449
450         optimizer.zero_grad()
451         loss.backward()
452         optimizer.step()
453
454     with torch.autograd.no_grad():
455
456         model.eval()
457
458         nb_test_samples, acc_test_loss = 0, 0.0
459
460         for input in task.batches(split = 'test'):
461             input = input.to(device)
462             output = model(input)
463             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
464             acc_test_loss += loss.item() * input.size(0)
465             nb_test_samples += input.size(0)
466
467         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
468         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
469
470         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
471
472         task.produce_results(k, model)
473
474     checkpoint = {
475         'nb_epochs_finished': k + 1,
476         'model_state': model.state_dict(),
477         'optimizer_state': optimizer.state_dict()
478     }
479
480     torch.save(checkpoint, args.checkpoint_name)
481
482 ######################################################################