Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     type = bool, default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     type = bool, default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ######################################################################
76
77 args = parser.parse_args()
78
79 log_file = open(args.log_filename, 'w')
80
81 if args.seed >= 0:
82     torch.manual_seed(args.seed)
83
84 ######################################################################
85
86 def log_string(s):
87     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
88
89     if log_file is not None:
90         log_file.write(t + s + '\n')
91         log_file.flush()
92
93     print(t + s)
94     sys.stdout.flush()
95
96 for n in vars(args):
97     log_string(f'args.{n} {getattr(args, n)}')
98
99 ######################################################################
100
101 class Task:
102     def batches(self, split = 'train'):
103         pass
104
105     def vocabulary_size(self):
106         pass
107
108     def produce_results(self, n_epoch, model, nb_tokens = 50):
109         pass
110
111 ######################################################################
112
113 import picoclvr
114
115 class TaskPicoCLVR(Task):
116
117     def __init__(self, batch_size,
118                  height = 6, width = 8, many_colors = False,
119                  device = torch.device('cpu')):
120
121         self.batch_size = batch_size
122         self.device = device
123         nb = args.data_size if args.data_size > 0 else 250000
124
125         descr = picoclvr.generate(
126             nb,
127             height = height, width = width,
128             many_colors = many_colors
129         )
130
131         descr = [ s.strip().split(' ') for s in descr ]
132         l = max([ len(s) for s in descr ])
133         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
134
135         tokens = set()
136         for s in descr:
137             for t in s: tokens.add(t)
138         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
139         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
140
141         t = [ [ self.token2id[u] for u in s ] for s in descr ]
142         data_input = torch.tensor(t, device = self.device)
143
144         self.test_input = data_input[:nb // 5]
145         self.train_input = data_input[nb // 5:]
146
147     def batches(self, split = 'train'):
148         assert split in { 'train', 'test' }
149         if split == 'train':
150             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
151                 yield batch
152         else:
153             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
154                 yield batch
155
156     def vocabulary_size(self):
157         return len(self.token2id)
158
159     def produce_results(self, n_epoch, model, nb_tokens = 50):
160         img = [ ]
161         nb_per_primer = 8
162
163         for primer in [
164                 'red above green <sep> green top <sep> blue right of red <img>',
165                 'there is red <sep> there is yellow <sep> there is blue <img>',
166                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
167                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
168         ]:
169
170             for k in range(nb_per_primer):
171                 t_primer = primer.strip().split(' ')
172                 t_generated = [ ]
173
174                 for j in range(nb_tokens):
175                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
176                     input = torch.tensor(t, device = self.device)
177                     output = model(input)
178                     logits = output[0, -1]
179                     if args.synthesis_sampling:
180                         dist = torch.distributions.categorical.Categorical(logits = logits)
181                         t = dist.sample()
182                     else:
183                         t = logits.argmax()
184                     t_generated.append(self.id2token[t.item()])
185
186                 descr = [ ' '.join(t_primer + t_generated) ]
187                 img += [ picoclvr.descr2img(descr) ]
188
189         img = torch.cat(img, 0)
190         file_name = f'result_picoclvr_{n_epoch:04d}.png'
191         torchvision.utils.save_image(img / 255.,
192                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
193         log_string(f'wrote {file_name}')
194
195 ######################################################################
196
197 class TaskWiki103(Task):
198
199     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
200                  device = torch.device('cpu')):
201
202         self.batch_size = batch_size
203         self.len_min = len_min
204         self.len_max = len_max
205         self.min_freq = min_freq
206         self.device = device
207
208         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
209         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
210
211         # Mostly for debug
212         if args.data_size > 0:
213             train_iter = itertools.islice(train_iter, args.data_size)
214
215         def yield_tokens():
216             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
217                 yield self.tokenizer(l)
218
219         self.vocab = torchtext.vocab.build_vocab_from_iterator(
220             yield_tokens(),
221             specials = [ '<unk>', '<non>' ],
222             min_freq = self.min_freq
223         )
224
225         self.vocab.set_default_index(self.vocab[ '<unk>' ])
226
227     def tensorize(self, s):
228         a = max(len(x) for x in s)
229         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
230
231     def yield_batches(self, ds):
232         s = [ ]
233         for l in ds:
234             q = self.tokenizer(l)
235             if len(q) >= self.len_min and len(q) <= self.len_max:
236                 s += [ q ]
237                 if len(s) == self.batch_size:
238                     yield self.tensorize(s)
239                     s = [ ]
240
241         if len(s) > 0:
242             yield self.tensorize(s)
243
244     def batches(self, split = 'train'):
245         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
246
247         # Mostly for debug
248         if args.data_size > 0:
249             data_iter = itertools.islice(data_iter, args.data_size)
250
251         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
252
253     def vocabulary_size(self):
254         return len(self.vocab)
255
256     def produce_results(self, n_epoch, model, nb_tokens = 50):
257         file_name = f'result_wiki103_{n_epoch:04d}.txt'
258
259         with open(file_name, 'w') as outfile:
260              for primer in [
261                      'the cat is hunting a',
262                      'paris is the capital',
263                      'cars are convenient',
264                      'the difference between men and women is',
265                      'the object was blue all over and green all over it was',
266                      'cherries are red and lemons are',
267                      'cherries are sweet and lemons are',
268                      'two plus three equals',
269                      'deep learning is',
270              ]:
271                  t_primer = self.tokenizer(primer)
272                  t_generated = [ ]
273
274                  for j in range(nb_tokens):
275
276                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
277                      output = model(input)
278                      logits = output[0, -1]
279                      if args.synthesis_sampling:
280                          dist = torch.distributions.categorical.Categorical(logits = logits)
281                          t = dist.sample()
282                      else:
283                          t = logits.argmax()
284                      t_generated.append(self.vocab.lookup_token(t))
285                      if t_generated[-1] == '<non>': break
286
287                  s = ' '.join(t_generated)
288
289                  outfile.write(f'<{primer}> {s}\n')
290
291         log_string(f'wrote {file_name}')
292
293 ######################################################################
294
295 class TaskMNIST(Task):
296
297     def __init__(self, batch_size, device = torch.device('cpu')):
298         self.device = device
299         self.batch_size = batch_size
300
301     def batches(self, split = 'train'):
302         assert split in { 'train', 'test' }
303         data_set = torchvision.datasets.MNIST(
304             root = './data', train = (split == 'train'),
305             download = True
306         )
307         data_input = data_set.data.view(-1, 28 * 28).long()
308         if args.data_size >= 0:
309             data_input = data_input[:args.data_size]
310         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
311             yield batch
312
313     def vocabulary_size(self):
314         return 256
315
316     def produce_results(self, n_epoch, model, nb_samples = 64):
317         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
318         for input in results.split(self.batch_size):
319             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
320                 output = model(input)
321                 logits = output[:, s]
322                 if args.synthesis_sampling:
323                     dist = torch.distributions.categorical.Categorical(logits = logits)
324                     t = dist.sample()
325                 else:
326                     t = logits.argmax(1)
327                 input[:, s + 1] = t
328
329         image_name = f'result_mnist_{n_epoch:04d}.png'
330         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
331                                      image_name, nrow = 16, pad_value = 0.8)
332         log_string(f'wrote {image_name}')
333
334 ######################################################################
335
336 def check_causality(model):
337     #m = model[1:]
338     input = torch.rand(1, 5, dim_model).requires_grad_()
339     output = m(input)
340     a = torch.zeros(output.size(1), input.size(1))
341     for k in range(output.size(1)):
342         for d in range(output.size(2)):
343             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
344             a[k] += g.squeeze(0).pow(2).sum(1)
345     print(a)
346
347 ######################################################################
348
349 log_string(f'device {device}')
350
351 if args.data == 'wiki103':
352     task = TaskWiki103(batch_size = args.batch_size, device = device)
353 elif args.data == 'mnist':
354     task = TaskMNIST(batch_size = args.batch_size, device = device)
355 elif args.data == 'picoclvr':
356     task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
357 else:
358     raise ValueError(f'Unknown dataset {args.data}.')
359
360 vocabulary_size = task.vocabulary_size()
361
362 log_string(f'vocabulary_size {vocabulary_size}')
363
364 ##############################
365
366 model = mygpt.MyGPT(
367     vocabulary_size = vocabulary_size,
368     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
369     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
370 )
371
372 model.to(device)
373
374 nb_parameters = sum(p.numel() for p in model.parameters())
375 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
376
377 ######################################################################
378
379 if args.optim == 'sgd':
380     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
381 elif args.optim == 'adam':
382     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
383 elif args.optim == 'adamw':
384     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
385 else:
386     raise ValueError(f'Unknown optimizer {args.optim}.')
387
388 ######################################################################
389
390 nb_epochs_finished = 0
391
392 try:
393     checkpoint = torch.load(args.checkpoint_name, map_location = device)
394     nb_epochs_finished = checkpoint['nb_epochs_finished']
395     model.load_state_dict(checkpoint['model_state'])
396     optimizer.load_state_dict(checkpoint['optimizer_state'])
397     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
398
399 except FileNotFoundError:
400     print('Starting from scratch.')
401
402 except:
403     print('Error when loading the checkpoint.')
404     exit(1)
405
406 ######################################################################
407
408 for k in range(nb_epochs_finished, args.nb_epochs):
409
410     model.train()
411
412     nb_train_samples, acc_train_loss = 0, 0.0
413
414     for input in task.batches(split = 'train'):
415         input = input.to(device)
416         output = model(input)
417         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
418         acc_train_loss += loss.item() * input.size(0)
419         nb_train_samples += input.size(0)
420
421         optimizer.zero_grad()
422         loss.backward()
423         optimizer.step()
424
425     with torch.autograd.no_grad():
426
427         model.eval()
428
429         nb_test_samples, acc_test_loss = 0, 0.0
430
431         for input in task.batches(split = 'test'):
432             input = input.to(device)
433             output = model(input)
434             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
435             acc_test_loss += loss.item() * input.size(0)
436             nb_test_samples += input.size(0)
437
438         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
439         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
440
441         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
442
443         task.produce_results(k, model)
444
445     checkpoint = {
446         'nb_epochs_finished': k + 1,
447         'model_state': model.state_dict(),
448         'optimizer_state': optimizer.state_dict()
449     }
450
451     torch.save(checkpoint, args.checkpoint_name)
452
453 ######################################################################