Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 parser.add_argument('--picoclvr_many_colors',
76                     action='store_true', default = False)
77
78 ######################################################################
79
80 args = parser.parse_args()
81
82 log_file = open(args.log_filename, 'w')
83
84 if args.seed >= 0:
85     torch.manual_seed(args.seed)
86
87 ######################################################################
88
89 def log_string(s):
90     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
91
92     if log_file is not None:
93         log_file.write(t + s + '\n')
94         log_file.flush()
95
96     print(t + s)
97     sys.stdout.flush()
98
99 for n in vars(args):
100     log_string(f'args.{n} {getattr(args, n)}')
101
102 ######################################################################
103
104 class Task:
105     def batches(self, split = 'train'):
106         pass
107
108     def vocabulary_size(self):
109         pass
110
111     def produce_results(self, n_epoch, model, nb_tokens = 50):
112         pass
113
114 ######################################################################
115
116 import picoclvr
117
118 class TaskPicoCLVR(Task):
119
120     def __init__(self, batch_size,
121                  height = 6, width = 8, many_colors = False,
122                  device = torch.device('cpu')):
123
124         self.batch_size = batch_size
125         self.device = device
126         nb = args.data_size if args.data_size > 0 else 250000
127
128         descr = picoclvr.generate(
129             nb,
130             height = height, width = width,
131             many_colors = many_colors
132         )
133
134         # self.test_descr = descr[:nb // 5]
135         # self.train_descr = descr[nb // 5:]
136
137         descr = [ s.strip().split(' ') for s in descr ]
138         l = max([ len(s) for s in descr ])
139         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
140
141         tokens = set()
142         for s in descr:
143             for t in s: tokens.add(t)
144         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
145         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
146
147         t = [ [ self.token2id[u] for u in s ] for s in descr ]
148         data_input = torch.tensor(t, device = self.device)
149
150         self.test_input = data_input[:nb // 5]
151         self.train_input = data_input[nb // 5:]
152
153     def batches(self, split = 'train'):
154         assert split in { 'train', 'test' }
155         if split == 'train':
156             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
157                 yield batch
158         else:
159             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
160                 yield batch
161
162     def vocabulary_size(self):
163         return len(self.token2id)
164
165     def generate(self, primer, model, nb_tokens):
166         t_primer = primer.strip().split(' ')
167         t_generated = [ ]
168
169         for j in range(nb_tokens):
170             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
171             input = torch.tensor(t, device = self.device)
172             output = model(input)
173             logits = output[0, -1]
174             if args.synthesis_sampling:
175                 dist = torch.distributions.categorical.Categorical(logits = logits)
176                 t = dist.sample()
177             else:
178                 t = logits.argmax()
179             t_generated.append(self.id2token[t.item()])
180
181         return ' '.join(t_primer + t_generated)
182
183     def produce_results(self, n_epoch, model, nb_tokens = 50):
184         descr = [ ]
185         nb_per_primer = 8
186
187         for primer in [
188                 'red above green <sep> green top <sep> blue right of red <img>',
189                 'there is red <sep> there is yellow <sep> there is blue <img>',
190                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
191                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
192         ]:
193
194             for k in range(nb_per_primer):
195                 descr.append(self.generate(primer, model, nb_tokens))
196
197         img = [ picoclvr.descr2img(d) for d in descr ]
198         img = torch.cat(img, 0)
199         file_name = f'result_picoclvr_{n_epoch:04d}.png'
200         torchvision.utils.save_image(img / 255.,
201                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
202         log_string(f'wrote {file_name}')
203
204         nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr) ] )
205         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
206
207 ######################################################################
208
209 class TaskWiki103(Task):
210
211     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
212                  device = torch.device('cpu')):
213
214         self.batch_size = batch_size
215         self.len_min = len_min
216         self.len_max = len_max
217         self.min_freq = min_freq
218         self.device = device
219
220         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
221         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
222
223         # Mostly for debug
224         if args.data_size > 0:
225             train_iter = itertools.islice(train_iter, args.data_size)
226
227         def yield_tokens():
228             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
229                 yield self.tokenizer(l)
230
231         self.vocab = torchtext.vocab.build_vocab_from_iterator(
232             yield_tokens(),
233             specials = [ '<unk>', '<non>' ],
234             min_freq = self.min_freq
235         )
236
237         self.vocab.set_default_index(self.vocab[ '<unk>' ])
238
239     def tensorize(self, s):
240         a = max(len(x) for x in s)
241         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
242
243     def yield_batches(self, ds):
244         s = [ ]
245         for l in ds:
246             q = self.tokenizer(l)
247             if len(q) >= self.len_min and len(q) <= self.len_max:
248                 s += [ q ]
249                 if len(s) == self.batch_size:
250                     yield self.tensorize(s)
251                     s = [ ]
252
253         if len(s) > 0:
254             yield self.tensorize(s)
255
256     def batches(self, split = 'train'):
257         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
258
259         # Mostly for debug
260         if args.data_size > 0:
261             data_iter = itertools.islice(data_iter, args.data_size)
262
263         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
264
265     def vocabulary_size(self):
266         return len(self.vocab)
267
268     def produce_results(self, n_epoch, model, nb_tokens = 50):
269         file_name = f'result_wiki103_{n_epoch:04d}.txt'
270
271         with open(file_name, 'w') as outfile:
272              for primer in [
273                      'the cat is hunting a',
274                      'paris is the capital',
275                      'cars are convenient',
276                      'the difference between men and women is',
277                      'the object was blue all over and green all over it was',
278                      'cherries are red and lemons are',
279                      'cherries are sweet and lemons are',
280                      'two plus three equals',
281                      'deep learning is',
282              ]:
283                  t_primer = self.tokenizer(primer)
284                  t_generated = [ ]
285
286                  for j in range(nb_tokens):
287
288                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
289                      output = model(input)
290                      logits = output[0, -1]
291                      if args.synthesis_sampling:
292                          dist = torch.distributions.categorical.Categorical(logits = logits)
293                          t = dist.sample()
294                      else:
295                          t = logits.argmax()
296                      t_generated.append(self.vocab.lookup_token(t))
297                      if t_generated[-1] == '<non>': break
298
299                  s = ' '.join(t_generated)
300
301                  outfile.write(f'<{primer}> {s}\n')
302
303         log_string(f'wrote {file_name}')
304
305 ######################################################################
306
307 class TaskMNIST(Task):
308
309     def __init__(self, batch_size, device = torch.device('cpu')):
310         self.device = device
311         self.batch_size = batch_size
312
313     def batches(self, split = 'train'):
314         assert split in { 'train', 'test' }
315         data_set = torchvision.datasets.MNIST(
316             root = './data', train = (split == 'train'),
317             download = True
318         )
319         data_input = data_set.data.view(-1, 28 * 28).long()
320         if args.data_size >= 0:
321             data_input = data_input[:args.data_size]
322         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
323             yield batch
324
325     def vocabulary_size(self):
326         return 256
327
328     def produce_results(self, n_epoch, model, nb_samples = 64):
329         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
330         for input in results.split(self.batch_size):
331             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
332                 output = model(input)
333                 logits = output[:, s]
334                 if args.synthesis_sampling:
335                     dist = torch.distributions.categorical.Categorical(logits = logits)
336                     t = dist.sample()
337                 else:
338                     t = logits.argmax(1)
339                 input[:, s + 1] = t
340
341         image_name = f'result_mnist_{n_epoch:04d}.png'
342         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
343                                      image_name, nrow = 16, pad_value = 0.8)
344         log_string(f'wrote {image_name}')
345
346 ######################################################################
347
348 def check_causality(model):
349     #m = model[1:]
350     input = torch.rand(1, 5, dim_model).requires_grad_()
351     output = m(input)
352     a = torch.zeros(output.size(1), input.size(1))
353     for k in range(output.size(1)):
354         for d in range(output.size(2)):
355             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
356             a[k] += g.squeeze(0).pow(2).sum(1)
357     print(a)
358
359 ######################################################################
360
361 log_string(f'device {device}')
362
363 if args.data == 'wiki103':
364     task = TaskWiki103(batch_size = args.batch_size, device = device)
365 elif args.data == 'mnist':
366     task = TaskMNIST(batch_size = args.batch_size, device = device)
367 elif args.data == 'picoclvr':
368     task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
369 else:
370     raise ValueError(f'Unknown dataset {args.data}.')
371
372 vocabulary_size = task.vocabulary_size()
373
374 log_string(f'vocabulary_size {vocabulary_size}')
375
376 ##############################
377
378 model = mygpt.MyGPT(
379     vocabulary_size = vocabulary_size,
380     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
381     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
382 )
383
384 model.to(device)
385
386 nb_parameters = sum(p.numel() for p in model.parameters())
387 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
388
389 ######################################################################
390
391 if args.optim == 'sgd':
392     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
393 elif args.optim == 'adam':
394     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
395 elif args.optim == 'adamw':
396     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
397 else:
398     raise ValueError(f'Unknown optimizer {args.optim}.')
399
400 ######################################################################
401
402 nb_epochs_finished = 0
403
404 try:
405     checkpoint = torch.load(args.checkpoint_name, map_location = device)
406     nb_epochs_finished = checkpoint['nb_epochs_finished']
407     model.load_state_dict(checkpoint['model_state'])
408     optimizer.load_state_dict(checkpoint['optimizer_state'])
409     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
410
411 except FileNotFoundError:
412     print('Starting from scratch.')
413
414 except:
415     print('Error when loading the checkpoint.')
416     exit(1)
417
418 ######################################################################
419
420 for k in range(nb_epochs_finished, args.nb_epochs):
421
422     model.train()
423
424     nb_train_samples, acc_train_loss = 0, 0.0
425
426     for input in task.batches(split = 'train'):
427         input = input.to(device)
428         output = model(input)
429         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
430         acc_train_loss += loss.item() * input.size(0)
431         nb_train_samples += input.size(0)
432
433         optimizer.zero_grad()
434         loss.backward()
435         optimizer.step()
436
437     with torch.autograd.no_grad():
438
439         model.eval()
440
441         nb_test_samples, acc_test_loss = 0, 0.0
442
443         for input in task.batches(split = 'test'):
444             input = input.to(device)
445             output = model(input)
446             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
447             acc_test_loss += loss.item() * input.size(0)
448             nb_test_samples += input.size(0)
449
450         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
451         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
452
453         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
454
455         task.produce_results(k, model)
456
457     checkpoint = {
458         'nb_epochs_finished': k + 1,
459         'model_state': model.state_dict(),
460         'optimizer_state': optimizer.state_dict()
461     }
462
463     torch.save(checkpoint, args.checkpoint_name)
464
465 ######################################################################