Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_many_colors',
82                     action='store_true', default = False)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 class Task:
117     def batches(self, split = 'train'):
118         pass
119
120     def vocabulary_size(self):
121         pass
122
123     def produce_results(self, n_epoch, model, nb_tokens = 50):
124         pass
125
126 ######################################################################
127
128 import picoclvr
129
130 class TaskPicoCLVR(Task):
131
132     def __init__(self, batch_size,
133                  height, width, many_colors = False,
134                  device = torch.device('cpu')):
135
136         def generate_descr(nb):
137             descr = picoclvr.generate(
138                 nb,
139                 height = self.height, width = self.width,
140                 many_colors = many_colors
141             )
142
143             descr = [ s.strip().split(' ') for s in descr ]
144             l = max([ len(s) for s in descr ])
145             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
146
147             return descr
148
149         self.height = height
150         self.width = width
151         self.batch_size = batch_size
152         self.device = device
153         nb = args.data_size if args.data_size > 0 else 250000
154
155         self.train_descr = generate_descr((nb * 4) // 5)
156         self.test_descr = generate_descr((nb * 1) // 5)
157
158         # Build the tokenizer
159         tokens = set()
160         for d in [ self.train_descr, self.test_descr ]:
161             for s in d:
162                 for t in s: tokens.add(t)
163         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
164         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
165
166         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
167         self.train_input = torch.tensor(t, device = self.device)
168         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
169         self.test_input = torch.tensor(t, device = self.device)
170
171     def batches(self, split = 'train'):
172         assert split in { 'train', 'test' }
173         if split == 'train':
174             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
175                 yield batch
176         else:
177             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
178                 yield batch
179
180     def vocabulary_size(self):
181         return len(self.token2id)
182
183     def generate(self, primer, model, nb_tokens):
184         t_primer = primer.strip().split(' ')
185         t_generated = [ ]
186
187         for j in range(nb_tokens):
188             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
189             input = torch.tensor(t, device = self.device)
190             output = model(input)
191             logits = output[0, -1]
192             if args.synthesis_sampling:
193                 dist = torch.distributions.categorical.Categorical(logits = logits)
194                 t = dist.sample()
195             else:
196                 t = logits.argmax()
197             t_generated.append(self.id2token[t.item()])
198
199         return ' '.join(t_primer + t_generated)
200
201     def produce_results(self, n_epoch, model, nb_tokens = None):
202         if nb_tokens is None:
203             nb_tokens = self.height * self.width + 3
204         descr = [ ]
205         nb_per_primer = 8
206
207         for primer in [
208                 'red above green <sep> green top <sep> blue right of red <img>',
209                 'there is red <sep> there is yellow <sep> there is blue <img>',
210                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
211                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
212         ]:
213
214             for k in range(nb_per_primer):
215                 descr.append(self.generate(primer, model, nb_tokens))
216
217         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
218         img = torch.cat(img, 0)
219         image_name = f'result_picoclvr_{n_epoch:04d}.png'
220         torchvision.utils.save_image(
221             img / 255.,
222             image_name, nrow = nb_per_primer, pad_value = 0.8
223         )
224         log_string(f'wrote {image_name}')
225
226         nb_missing = sum( [
227             x[2] for x in picoclvr.nb_missing_properties(
228                 descr,
229                 height = self.height, width = self.width
230             )
231         ] )
232
233         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
234
235 ######################################################################
236
237 class TaskWiki103(Task):
238
239     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
240                  device = torch.device('cpu')):
241
242         self.batch_size = batch_size
243         self.len_min = len_min
244         self.len_max = len_max
245         self.min_freq = min_freq
246         self.device = device
247
248         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
249         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
250
251         # Mostly for debug
252         if args.data_size > 0:
253             train_iter = itertools.islice(train_iter, args.data_size)
254
255         def yield_tokens():
256             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
257                 yield self.tokenizer(l)
258
259         self.vocab = torchtext.vocab.build_vocab_from_iterator(
260             yield_tokens(),
261             specials = [ '<unk>', '<non>' ],
262             min_freq = self.min_freq
263         )
264
265         self.vocab.set_default_index(self.vocab[ '<unk>' ])
266
267     def tensorize(self, s):
268         a = max(len(x) for x in s)
269         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
270
271     def yield_batches(self, ds):
272         s = [ ]
273         for l in ds:
274             q = self.tokenizer(l)
275             if len(q) >= self.len_min and len(q) <= self.len_max:
276                 s += [ q ]
277                 if len(s) == self.batch_size:
278                     yield self.tensorize(s)
279                     s = [ ]
280
281         if len(s) > 0:
282             yield self.tensorize(s)
283
284     def batches(self, split = 'train'):
285         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
286
287         # Mostly for debug
288         if args.data_size > 0:
289             data_iter = itertools.islice(data_iter, args.data_size)
290
291         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
292
293     def vocabulary_size(self):
294         return len(self.vocab)
295
296     def produce_results(self, n_epoch, model, nb_tokens = 50):
297         file_name = f'result_wiki103_{n_epoch:04d}.txt'
298
299         with open(file_name, 'w') as outfile:
300              for primer in [
301                      'the cat is hunting a',
302                      'paris is the capital',
303                      'cars are convenient',
304                      'the difference between men and women is',
305                      'the object was blue all over and green all over it was',
306                      'cherries are red and lemons are',
307                      'cherries are sweet and lemons are',
308                      'two plus three equals',
309                      'deep learning is',
310              ]:
311                  t_primer = self.tokenizer(primer)
312                  t_generated = [ ]
313
314                  for j in range(nb_tokens):
315
316                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
317                      output = model(input)
318                      logits = output[0, -1]
319                      if args.synthesis_sampling:
320                          dist = torch.distributions.categorical.Categorical(logits = logits)
321                          t = dist.sample()
322                      else:
323                          t = logits.argmax()
324                      t_generated.append(self.vocab.lookup_token(t))
325                      if t_generated[-1] == '<non>': break
326
327                  s = ' '.join(t_generated)
328
329                  outfile.write(f'<{primer}> {s}\n')
330
331         log_string(f'wrote {file_name}')
332
333 ######################################################################
334
335 class TaskMNIST(Task):
336
337     def __init__(self, batch_size, device = torch.device('cpu')):
338         self.device = device
339         self.batch_size = batch_size
340
341     def batches(self, split = 'train'):
342         assert split in { 'train', 'test' }
343         data_set = torchvision.datasets.MNIST(
344             root = './data', train = (split == 'train'),
345             download = True
346         )
347         data_input = data_set.data.view(-1, 28 * 28).long()
348         if args.data_size >= 0:
349             data_input = data_input[:args.data_size]
350         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
351             yield batch
352
353     def vocabulary_size(self):
354         return 256
355
356     def produce_results(self, n_epoch, model, nb_samples = 64):
357         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
358         for input in results.split(self.batch_size):
359             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
360                 output = model(input)
361                 logits = output[:, s]
362                 if args.synthesis_sampling:
363                     dist = torch.distributions.categorical.Categorical(logits = logits)
364                     t = dist.sample()
365                 else:
366                     t = logits.argmax(1)
367                 input[:, s + 1] = t
368
369         image_name = f'result_mnist_{n_epoch:04d}.png'
370         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
371                                      image_name, nrow = 16, pad_value = 0.8)
372         log_string(f'wrote {image_name}')
373
374 ######################################################################
375
376 def check_causality(model):
377     #m = model[1:]
378     input = torch.rand(1, 5, dim_model).requires_grad_()
379     output = m(input)
380     a = torch.zeros(output.size(1), input.size(1))
381     for k in range(output.size(1)):
382         for d in range(output.size(2)):
383             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
384             a[k] += g.squeeze(0).pow(2).sum(1)
385     print(a)
386
387 ######################################################################
388
389 log_string(f'device {device}')
390
391 if args.data == 'wiki103':
392     task = TaskWiki103(batch_size = args.batch_size, device = device)
393 elif args.data == 'mnist':
394     task = TaskMNIST(batch_size = args.batch_size, device = device)
395 elif args.data == 'picoclvr':
396     task = TaskPicoCLVR(batch_size = args.batch_size,
397                         height = args.picoclvr_height,
398                         width = args.picoclvr_width,
399                         many_colors = args.picoclvr_many_colors,
400                         device = device)
401 else:
402     raise ValueError(f'Unknown dataset {args.data}.')
403
404 vocabulary_size = task.vocabulary_size()
405
406 log_string(f'vocabulary_size {vocabulary_size}')
407
408 ##############################
409
410 model = mygpt.MyGPT(
411     vocabulary_size = vocabulary_size,
412     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
413     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
414 )
415
416 model.to(device)
417
418 nb_parameters = sum(p.numel() for p in model.parameters())
419 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
420
421 ######################################################################
422
423 if args.optim == 'sgd':
424     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
425 elif args.optim == 'adam':
426     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
427 elif args.optim == 'adamw':
428     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
429 else:
430     raise ValueError(f'Unknown optimizer {args.optim}.')
431
432 ######################################################################
433
434 nb_epochs_finished = 0
435
436 if args.no_checkpoint:
437     log_string(f'Not trying to load checkpoint.')
438
439 else:
440     try:
441         checkpoint = torch.load(args.checkpoint_name, map_location = device)
442         nb_epochs_finished = checkpoint['nb_epochs_finished']
443         model.load_state_dict(checkpoint['model_state'])
444         optimizer.load_state_dict(checkpoint['optimizer_state'])
445         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
446
447     except FileNotFoundError:
448         log_string('Starting from scratch.')
449
450     except:
451         log_string('Error when loading the checkpoint.')
452         exit(1)
453
454 ######################################################################
455
456 for k in range(nb_epochs_finished, args.nb_epochs):
457
458     model.train()
459
460     nb_train_samples, acc_train_loss = 0, 0.0
461
462     for input in task.batches(split = 'train'):
463         input = input.to(device)
464         output = model(input)
465         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
466         acc_train_loss += loss.item() * input.size(0)
467         nb_train_samples += input.size(0)
468
469         optimizer.zero_grad()
470         loss.backward()
471         optimizer.step()
472
473     with torch.autograd.no_grad():
474
475         model.eval()
476
477         nb_test_samples, acc_test_loss = 0, 0.0
478
479         for input in task.batches(split = 'test'):
480             input = input.to(device)
481             output = model(input)
482             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
483             acc_test_loss += loss.item() * input.size(0)
484             nb_test_samples += input.size(0)
485
486         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
487         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
488
489         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
490
491         task.produce_results(k, model)
492
493     checkpoint = {
494         'nb_epochs_finished': k + 1,
495         'model_state': model.state_dict(),
496         'optimizer_state': optimizer.state_dict()
497     }
498
499     torch.save(checkpoint, args.checkpoint_name)
500
501 ######################################################################