f496e9953777d898da5c2938b902335cb047477e
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_many_colors',
82                     action='store_true', default = False)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 class Task:
117     def batches(self, split = 'train'):
118         pass
119
120     def vocabulary_size(self):
121         pass
122
123     def produce_results(self, n_epoch, model, nb_tokens = 50):
124         pass
125
126 ######################################################################
127
128 import picoclvr
129
130 class TaskPicoCLVR(Task):
131
132     def __init__(self, batch_size,
133                  height, width, many_colors = False,
134                  device = torch.device('cpu')):
135
136         def generate_descr(nb):
137             descr = picoclvr.generate(
138                 nb,
139                 height = self.height, width = self.width,
140                 many_colors = many_colors
141             )
142
143             descr = [ s.strip().split(' ') for s in descr ]
144             l = max([ len(s) for s in descr ])
145             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
146
147             return descr
148
149         self.height = height
150         self.width = width
151         self.batch_size = batch_size
152         self.device = device
153         nb = args.data_size if args.data_size > 0 else 250000
154
155         self.train_descr = generate_descr((nb * 4) // 5)
156         self.test_descr = generate_descr((nb * 1) // 5)
157
158         tokens = set()
159         for d in [ self.train_descr, self.test_descr ]:
160             for s in d:
161                 for t in s: tokens.add(t)
162         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
163         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
164
165         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
166         self.train_input = torch.tensor(t, device = self.device)
167         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
168         self.test_input = torch.tensor(t, device = self.device)
169
170     def batches(self, split = 'train'):
171         assert split in { 'train', 'test' }
172         if split == 'train':
173             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
174                 yield batch
175         else:
176             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
177                 yield batch
178
179     def vocabulary_size(self):
180         return len(self.token2id)
181
182     def generate(self, primer, model, nb_tokens):
183         t_primer = primer.strip().split(' ')
184         t_generated = [ ]
185
186         for j in range(nb_tokens):
187             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
188             input = torch.tensor(t, device = self.device)
189             output = model(input)
190             logits = output[0, -1]
191             if args.synthesis_sampling:
192                 dist = torch.distributions.categorical.Categorical(logits = logits)
193                 t = dist.sample()
194             else:
195                 t = logits.argmax()
196             t_generated.append(self.id2token[t.item()])
197
198         return ' '.join(t_primer + t_generated)
199
200     def produce_results(self, n_epoch, model, nb_tokens = None):
201         if nb_tokens is None:
202             nb_tokens = self.height * self.width + 3
203         descr = [ ]
204         nb_per_primer = 8
205
206         for primer in [
207                 'red above green <sep> green top <sep> blue right of red <img>',
208                 'there is red <sep> there is yellow <sep> there is blue <img>',
209                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
210                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
211         ]:
212
213             for k in range(nb_per_primer):
214                 descr.append(self.generate(primer, model, nb_tokens))
215
216         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
217         img = torch.cat(img, 0)
218         file_name = f'result_picoclvr_{n_epoch:04d}.png'
219         torchvision.utils.save_image(
220             img / 255.,
221             file_name, nrow = nb_per_primer, pad_value = 0.8
222         )
223         log_string(f'wrote {file_name}')
224
225         nb_missing = sum( [
226             x[2] for x in picoclvr.nb_missing_properties(
227                 descr,
228                 height = self.height, width = self.width
229             )
230         ] )
231
232         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
233
234 ######################################################################
235
236 class TaskWiki103(Task):
237
238     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
239                  device = torch.device('cpu')):
240
241         self.batch_size = batch_size
242         self.len_min = len_min
243         self.len_max = len_max
244         self.min_freq = min_freq
245         self.device = device
246
247         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
248         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
249
250         # Mostly for debug
251         if args.data_size > 0:
252             train_iter = itertools.islice(train_iter, args.data_size)
253
254         def yield_tokens():
255             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
256                 yield self.tokenizer(l)
257
258         self.vocab = torchtext.vocab.build_vocab_from_iterator(
259             yield_tokens(),
260             specials = [ '<unk>', '<non>' ],
261             min_freq = self.min_freq
262         )
263
264         self.vocab.set_default_index(self.vocab[ '<unk>' ])
265
266     def tensorize(self, s):
267         a = max(len(x) for x in s)
268         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
269
270     def yield_batches(self, ds):
271         s = [ ]
272         for l in ds:
273             q = self.tokenizer(l)
274             if len(q) >= self.len_min and len(q) <= self.len_max:
275                 s += [ q ]
276                 if len(s) == self.batch_size:
277                     yield self.tensorize(s)
278                     s = [ ]
279
280         if len(s) > 0:
281             yield self.tensorize(s)
282
283     def batches(self, split = 'train'):
284         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
285
286         # Mostly for debug
287         if args.data_size > 0:
288             data_iter = itertools.islice(data_iter, args.data_size)
289
290         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
291
292     def vocabulary_size(self):
293         return len(self.vocab)
294
295     def produce_results(self, n_epoch, model, nb_tokens = 50):
296         file_name = f'result_wiki103_{n_epoch:04d}.txt'
297
298         with open(file_name, 'w') as outfile:
299              for primer in [
300                      'the cat is hunting a',
301                      'paris is the capital',
302                      'cars are convenient',
303                      'the difference between men and women is',
304                      'the object was blue all over and green all over it was',
305                      'cherries are red and lemons are',
306                      'cherries are sweet and lemons are',
307                      'two plus three equals',
308                      'deep learning is',
309              ]:
310                  t_primer = self.tokenizer(primer)
311                  t_generated = [ ]
312
313                  for j in range(nb_tokens):
314
315                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
316                      output = model(input)
317                      logits = output[0, -1]
318                      if args.synthesis_sampling:
319                          dist = torch.distributions.categorical.Categorical(logits = logits)
320                          t = dist.sample()
321                      else:
322                          t = logits.argmax()
323                      t_generated.append(self.vocab.lookup_token(t))
324                      if t_generated[-1] == '<non>': break
325
326                  s = ' '.join(t_generated)
327
328                  outfile.write(f'<{primer}> {s}\n')
329
330         log_string(f'wrote {file_name}')
331
332 ######################################################################
333
334 class TaskMNIST(Task):
335
336     def __init__(self, batch_size, device = torch.device('cpu')):
337         self.device = device
338         self.batch_size = batch_size
339
340     def batches(self, split = 'train'):
341         assert split in { 'train', 'test' }
342         data_set = torchvision.datasets.MNIST(
343             root = './data', train = (split == 'train'),
344             download = True
345         )
346         data_input = data_set.data.view(-1, 28 * 28).long()
347         if args.data_size >= 0:
348             data_input = data_input[:args.data_size]
349         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
350             yield batch
351
352     def vocabulary_size(self):
353         return 256
354
355     def produce_results(self, n_epoch, model, nb_samples = 64):
356         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
357         for input in results.split(self.batch_size):
358             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
359                 output = model(input)
360                 logits = output[:, s]
361                 if args.synthesis_sampling:
362                     dist = torch.distributions.categorical.Categorical(logits = logits)
363                     t = dist.sample()
364                 else:
365                     t = logits.argmax(1)
366                 input[:, s + 1] = t
367
368         image_name = f'result_mnist_{n_epoch:04d}.png'
369         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
370                                      image_name, nrow = 16, pad_value = 0.8)
371         log_string(f'wrote {image_name}')
372
373 ######################################################################
374
375 def check_causality(model):
376     #m = model[1:]
377     input = torch.rand(1, 5, dim_model).requires_grad_()
378     output = m(input)
379     a = torch.zeros(output.size(1), input.size(1))
380     for k in range(output.size(1)):
381         for d in range(output.size(2)):
382             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
383             a[k] += g.squeeze(0).pow(2).sum(1)
384     print(a)
385
386 ######################################################################
387
388 log_string(f'device {device}')
389
390 if args.data == 'wiki103':
391     task = TaskWiki103(batch_size = args.batch_size, device = device)
392 elif args.data == 'mnist':
393     task = TaskMNIST(batch_size = args.batch_size, device = device)
394 elif args.data == 'picoclvr':
395     task = TaskPicoCLVR(batch_size = args.batch_size,
396                         height = args.picoclvr_height,
397                         width = args.picoclvr_width,
398                         many_colors = args.picoclvr_many_colors,
399                         device = device)
400 else:
401     raise ValueError(f'Unknown dataset {args.data}.')
402
403 vocabulary_size = task.vocabulary_size()
404
405 log_string(f'vocabulary_size {vocabulary_size}')
406
407 ##############################
408
409 model = mygpt.MyGPT(
410     vocabulary_size = vocabulary_size,
411     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
412     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
413 )
414
415 model.to(device)
416
417 nb_parameters = sum(p.numel() for p in model.parameters())
418 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
419
420 ######################################################################
421
422 if args.optim == 'sgd':
423     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
424 elif args.optim == 'adam':
425     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
426 elif args.optim == 'adamw':
427     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
428 else:
429     raise ValueError(f'Unknown optimizer {args.optim}.')
430
431 ######################################################################
432
433 nb_epochs_finished = 0
434
435 if args.no_checkpoint:
436     log_string(f'Not trying to load checkpoint.')
437
438 else:
439     try:
440         checkpoint = torch.load(args.checkpoint_name, map_location = device)
441         nb_epochs_finished = checkpoint['nb_epochs_finished']
442         model.load_state_dict(checkpoint['model_state'])
443         optimizer.load_state_dict(checkpoint['optimizer_state'])
444         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
445
446     except FileNotFoundError:
447         log_string('Starting from scratch.')
448
449     except:
450         log_string('Error when loading the checkpoint.')
451         exit(1)
452
453 ######################################################################
454
455 for k in range(nb_epochs_finished, args.nb_epochs):
456
457     model.train()
458
459     nb_train_samples, acc_train_loss = 0, 0.0
460
461     for input in task.batches(split = 'train'):
462         input = input.to(device)
463         output = model(input)
464         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
465         acc_train_loss += loss.item() * input.size(0)
466         nb_train_samples += input.size(0)
467
468         optimizer.zero_grad()
469         loss.backward()
470         optimizer.step()
471
472     with torch.autograd.no_grad():
473
474         model.eval()
475
476         nb_test_samples, acc_test_loss = 0, 0.0
477
478         for input in task.batches(split = 'test'):
479             input = input.to(device)
480             output = model(input)
481             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
482             acc_test_loss += loss.item() * input.size(0)
483             nb_test_samples += input.size(0)
484
485         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
486         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
487
488         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
489
490         task.produce_results(k, model)
491
492     checkpoint = {
493         'nb_epochs_finished': k + 1,
494         'model_state': model.state_dict(),
495         'optimizer_state': optimizer.state_dict()
496     }
497
498     torch.save(checkpoint, args.checkpoint_name)
499
500 ######################################################################