83c305bd351d5f4c2ec4659679d2a54a1001ba61
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 parser.add_argument('--picoclvr_many_colors',
76                     action='store_true', default = False)
77
78 ######################################################################
79
80 args = parser.parse_args()
81
82 log_file = open(args.log_filename, 'w')
83
84 if args.seed >= 0:
85     torch.manual_seed(args.seed)
86
87 ######################################################################
88
89 def log_string(s):
90     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
91
92     if log_file is not None:
93         log_file.write(t + s + '\n')
94         log_file.flush()
95
96     print(t + s)
97     sys.stdout.flush()
98
99 for n in vars(args):
100     log_string(f'args.{n} {getattr(args, n)}')
101
102 ######################################################################
103
104 class Task:
105     def batches(self, split = 'train'):
106         pass
107
108     def vocabulary_size(self):
109         pass
110
111     def produce_results(self, n_epoch, model, nb_tokens = 50):
112         pass
113
114 ######################################################################
115
116 import picoclvr
117
118 class TaskPicoCLVR(Task):
119
120     def __init__(self, batch_size,
121                  height = 6, width = 8, many_colors = False,
122                  device = torch.device('cpu')):
123
124         self.batch_size = batch_size
125         self.device = device
126         nb = args.data_size if args.data_size > 0 else 250000
127
128         descr = picoclvr.generate(
129             nb,
130             height = height, width = width,
131             many_colors = many_colors
132         )
133
134         # self.test_descr = descr[:nb // 5]
135         # self.train_descr = descr[nb // 5:]
136
137         descr = [ s.strip().split(' ') for s in descr ]
138         l = max([ len(s) for s in descr ])
139         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
140
141         tokens = set()
142         for s in descr:
143             for t in s: tokens.add(t)
144         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
145         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
146
147         t = [ [ self.token2id[u] for u in s ] for s in descr ]
148         data_input = torch.tensor(t, device = self.device)
149
150         self.test_input = data_input[:nb // 5]
151         self.train_input = data_input[nb // 5:]
152
153     def batches(self, split = 'train'):
154         assert split in { 'train', 'test' }
155         if split == 'train':
156             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
157                 yield batch
158         else:
159             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
160                 yield batch
161
162     def vocabulary_size(self):
163         return len(self.token2id)
164
165     def generate(self, primer, model, nb_tokens):
166         t_primer = primer.strip().split(' ')
167         t_generated = [ ]
168
169         for j in range(nb_tokens):
170             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
171             input = torch.tensor(t, device = self.device)
172             output = model(input)
173             logits = output[0, -1]
174             if args.synthesis_sampling:
175                 dist = torch.distributions.categorical.Categorical(logits = logits)
176                 t = dist.sample()
177             else:
178                 t = logits.argmax()
179             t_generated.append(self.id2token[t.item()])
180
181         return ' '.join(t_primer + t_generated)
182
183     def produce_results(self, n_epoch, model, nb_tokens = 50):
184         descr = [ ]
185         nb_per_primer = 8
186
187         for primer in [
188                 'red above green <sep> green top <sep> blue right of red <img>',
189                 'there is red <sep> there is yellow <sep> there is blue <img>',
190                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
191                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
192         ]:
193
194             for k in range(nb_per_primer):
195                 descr.append(self.generate(primer, model, nb_tokens))
196
197         img = [ picoclvr.descr2img(d) for d in descr ]
198         img = torch.cat(img, 0)
199         file_name = f'result_picoclvr_{n_epoch:04d}.png'
200         torchvision.utils.save_image(img / 255.,
201                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
202         log_string(f'wrote {file_name}')
203
204         log_string(f'nb_misssing {picoclvr.nb_missing_properties(descr)}')
205
206 ######################################################################
207
208 class TaskWiki103(Task):
209
210     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
211                  device = torch.device('cpu')):
212
213         self.batch_size = batch_size
214         self.len_min = len_min
215         self.len_max = len_max
216         self.min_freq = min_freq
217         self.device = device
218
219         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
220         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
221
222         # Mostly for debug
223         if args.data_size > 0:
224             train_iter = itertools.islice(train_iter, args.data_size)
225
226         def yield_tokens():
227             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
228                 yield self.tokenizer(l)
229
230         self.vocab = torchtext.vocab.build_vocab_from_iterator(
231             yield_tokens(),
232             specials = [ '<unk>', '<non>' ],
233             min_freq = self.min_freq
234         )
235
236         self.vocab.set_default_index(self.vocab[ '<unk>' ])
237
238     def tensorize(self, s):
239         a = max(len(x) for x in s)
240         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
241
242     def yield_batches(self, ds):
243         s = [ ]
244         for l in ds:
245             q = self.tokenizer(l)
246             if len(q) >= self.len_min and len(q) <= self.len_max:
247                 s += [ q ]
248                 if len(s) == self.batch_size:
249                     yield self.tensorize(s)
250                     s = [ ]
251
252         if len(s) > 0:
253             yield self.tensorize(s)
254
255     def batches(self, split = 'train'):
256         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
257
258         # Mostly for debug
259         if args.data_size > 0:
260             data_iter = itertools.islice(data_iter, args.data_size)
261
262         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
263
264     def vocabulary_size(self):
265         return len(self.vocab)
266
267     def produce_results(self, n_epoch, model, nb_tokens = 50):
268         file_name = f'result_wiki103_{n_epoch:04d}.txt'
269
270         with open(file_name, 'w') as outfile:
271              for primer in [
272                      'the cat is hunting a',
273                      'paris is the capital',
274                      'cars are convenient',
275                      'the difference between men and women is',
276                      'the object was blue all over and green all over it was',
277                      'cherries are red and lemons are',
278                      'cherries are sweet and lemons are',
279                      'two plus three equals',
280                      'deep learning is',
281              ]:
282                  t_primer = self.tokenizer(primer)
283                  t_generated = [ ]
284
285                  for j in range(nb_tokens):
286
287                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
288                      output = model(input)
289                      logits = output[0, -1]
290                      if args.synthesis_sampling:
291                          dist = torch.distributions.categorical.Categorical(logits = logits)
292                          t = dist.sample()
293                      else:
294                          t = logits.argmax()
295                      t_generated.append(self.vocab.lookup_token(t))
296                      if t_generated[-1] == '<non>': break
297
298                  s = ' '.join(t_generated)
299
300                  outfile.write(f'<{primer}> {s}\n')
301
302         log_string(f'wrote {file_name}')
303
304 ######################################################################
305
306 class TaskMNIST(Task):
307
308     def __init__(self, batch_size, device = torch.device('cpu')):
309         self.device = device
310         self.batch_size = batch_size
311
312     def batches(self, split = 'train'):
313         assert split in { 'train', 'test' }
314         data_set = torchvision.datasets.MNIST(
315             root = './data', train = (split == 'train'),
316             download = True
317         )
318         data_input = data_set.data.view(-1, 28 * 28).long()
319         if args.data_size >= 0:
320             data_input = data_input[:args.data_size]
321         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
322             yield batch
323
324     def vocabulary_size(self):
325         return 256
326
327     def produce_results(self, n_epoch, model, nb_samples = 64):
328         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
329         for input in results.split(self.batch_size):
330             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
331                 output = model(input)
332                 logits = output[:, s]
333                 if args.synthesis_sampling:
334                     dist = torch.distributions.categorical.Categorical(logits = logits)
335                     t = dist.sample()
336                 else:
337                     t = logits.argmax(1)
338                 input[:, s + 1] = t
339
340         image_name = f'result_mnist_{n_epoch:04d}.png'
341         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
342                                      image_name, nrow = 16, pad_value = 0.8)
343         log_string(f'wrote {image_name}')
344
345 ######################################################################
346
347 def check_causality(model):
348     #m = model[1:]
349     input = torch.rand(1, 5, dim_model).requires_grad_()
350     output = m(input)
351     a = torch.zeros(output.size(1), input.size(1))
352     for k in range(output.size(1)):
353         for d in range(output.size(2)):
354             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
355             a[k] += g.squeeze(0).pow(2).sum(1)
356     print(a)
357
358 ######################################################################
359
360 log_string(f'device {device}')
361
362 if args.data == 'wiki103':
363     task = TaskWiki103(batch_size = args.batch_size, device = device)
364 elif args.data == 'mnist':
365     task = TaskMNIST(batch_size = args.batch_size, device = device)
366 elif args.data == 'picoclvr':
367     task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
368 else:
369     raise ValueError(f'Unknown dataset {args.data}.')
370
371 vocabulary_size = task.vocabulary_size()
372
373 log_string(f'vocabulary_size {vocabulary_size}')
374
375 ##############################
376
377 model = mygpt.MyGPT(
378     vocabulary_size = vocabulary_size,
379     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
380     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
381 )
382
383 model.to(device)
384
385 nb_parameters = sum(p.numel() for p in model.parameters())
386 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
387
388 ######################################################################
389
390 if args.optim == 'sgd':
391     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
392 elif args.optim == 'adam':
393     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
394 elif args.optim == 'adamw':
395     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
396 else:
397     raise ValueError(f'Unknown optimizer {args.optim}.')
398
399 ######################################################################
400
401 nb_epochs_finished = 0
402
403 try:
404     checkpoint = torch.load(args.checkpoint_name, map_location = device)
405     nb_epochs_finished = checkpoint['nb_epochs_finished']
406     model.load_state_dict(checkpoint['model_state'])
407     optimizer.load_state_dict(checkpoint['optimizer_state'])
408     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
409
410 except FileNotFoundError:
411     print('Starting from scratch.')
412
413 except:
414     print('Error when loading the checkpoint.')
415     exit(1)
416
417 ######################################################################
418
419 for k in range(nb_epochs_finished, args.nb_epochs):
420
421     model.train()
422
423     nb_train_samples, acc_train_loss = 0, 0.0
424
425     for input in task.batches(split = 'train'):
426         input = input.to(device)
427         output = model(input)
428         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
429         acc_train_loss += loss.item() * input.size(0)
430         nb_train_samples += input.size(0)
431
432         optimizer.zero_grad()
433         loss.backward()
434         optimizer.step()
435
436     with torch.autograd.no_grad():
437
438         model.eval()
439
440         nb_test_samples, acc_test_loss = 0, 0.0
441
442         for input in task.batches(split = 'test'):
443             input = input.to(device)
444             output = model(input)
445             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
446             acc_test_loss += loss.item() * input.size(0)
447             nb_test_samples += input.size(0)
448
449         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
450         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
451
452         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
453
454         task.produce_results(k, model)
455
456     checkpoint = {
457         'nb_epochs_finished': k + 1,
458         'model_state': model.state_dict(),
459         'optimizer_state': optimizer.state_dict()
460     }
461
462     torch.save(checkpoint, args.checkpoint_name)
463
464 ######################################################################