Replaced --synthesis_sampling with --deterministic_synthesis.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-3)
46
47 parser.add_argument('--learning_rate_end',
48                     type = float, default = 1e-6)
49
50 parser.add_argument('--dim_model',
51                     type = int, default = 512)
52
53 parser.add_argument('--dim_keys',
54                     type = int, default = 64)
55
56 parser.add_argument('--dim_hidden',
57                     type = int, default = 2048)
58
59 parser.add_argument('--nb_heads',
60                     type = int, default = 8)
61
62 parser.add_argument('--nb_blocks',
63                     type = int, default = 12)
64
65 parser.add_argument('--dropout',
66                     type = float, default = 0.1)
67
68 parser.add_argument('--deterministic_synthesis',
69                     action='store_true', default = False)
70
71 parser.add_argument('--no_checkpoint',
72                     action='store_true', default = False)
73
74 parser.add_argument('--checkpoint_name',
75                     type = str, default = 'checkpoint.pth')
76
77 ##############################
78 # picoclvr options
79
80 parser.add_argument('--picoclvr_nb_colors',
81                     type = int, default = 5)
82
83 parser.add_argument('--picoclvr_height',
84                     type = int, default = 12)
85
86 parser.add_argument('--picoclvr_width',
87                     type = int, default = 16)
88
89 ######################################################################
90
91 args = parser.parse_args()
92
93 log_file = open(args.log_filename, 'w')
94
95 if args.seed >= 0:
96     torch.manual_seed(args.seed)
97
98 ######################################################################
99
100 def log_string(s):
101     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
102
103     if log_file is not None:
104         log_file.write(t + s + '\n')
105         log_file.flush()
106
107     print(t + s)
108     sys.stdout.flush()
109
110 for n in vars(args):
111     log_string(f'args.{n} {getattr(args, n)}')
112
113 ######################################################################
114
115 def autoregression(
116         model, batch_size,
117         nb_samples, nb_tokens_to_generate, primer = None,
118         device = torch.device('cpu')
119 ):
120     results = torch.zeros(
121         nb_samples, nb_tokens_to_generate,
122         dtype = torch.int64, device = device
123     )
124
125     if primer is None:
126         first = 0
127     else:
128         first = primer.size(1)
129         results = torch.cat((primer, results), 1)
130
131     for input in results.split(batch_size):
132         for s in range(first, input.size(1)):
133             output = model(input)
134             logits = output[:, s]
135             if args.deterministic_synthesis:
136                 t_next = logits.argmax(1)
137             else:
138                 dist = torch.distributions.categorical.Categorical(logits = logits)
139                 t_next = dist.sample()
140             input[:, s] = t_next
141
142     return results
143
144 ######################################################################
145
146 class Task:
147     def batches(self, split = 'train'):
148         pass
149
150     def vocabulary_size(self):
151         pass
152
153     def produce_results(self, n_epoch, model):
154         pass
155
156 ######################################################################
157
158 import picoclvr
159
160 class TaskPicoCLVR(Task):
161
162     # Make a tensor from a list of strings
163     def tensorize(self, descr):
164         token_descr = [ s.strip().split(' ') for s in descr ]
165         l = max([ len(s) for s in token_descr ])
166         #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
167         token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
168         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
169         return torch.tensor(id_descr, device = self.device)
170
171     def trim(self, x, token = '<nul>'):
172         n = self.token2id[token]
173         i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
174         a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
175         return x[:, a:b]
176
177     def __init__(self, batch_size,
178                  height, width, nb_colors = 5,
179                  device = torch.device('cpu')):
180
181         def generate_descr(nb):
182             return picoclvr.generate(
183                 nb,
184                 height = self.height, width = self.width,
185                 nb_colors = nb_colors
186             )
187
188         self.height = height
189         self.width = width
190         self.batch_size = batch_size
191         self.device = device
192         nb = args.data_size if args.data_size > 0 else 250000
193
194         log_string(f'generating {nb} samples (can take some time)')
195         self.train_descr = generate_descr((nb * 4) // 5)
196         self.test_descr = generate_descr((nb * 1) // 5)
197
198         # Build the tokenizer
199         tokens = { '<nul>' }
200         for d in [ self.train_descr, self.test_descr ]:
201             for s in d:
202                 for t in s.strip().split(' '): tokens.add(t)
203         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
204         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
205
206         # Tokenize the train and test sets
207         self.train_input = self.tensorize(self.train_descr)
208         self.test_input = self.tensorize(self.test_descr)
209
210     def batches(self, split = 'train'):
211         assert split in { 'train', 'test' }
212         input = self.train_input if split == 'train' else self.test_input
213         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
214             yield self.trim(batch)
215
216     def vocabulary_size(self):
217         return len(self.token2id)
218
219     def produce_results(self, n_epoch, model):
220         nb_tokens_to_generate = self.height * self.width + 3
221         result_descr = [ ]
222         nb_per_primer = 8
223
224         for primer_descr in [
225                 'red above green <sep> green top <sep> blue right of red <img>',
226                 'there is red <sep> there is yellow <sep> there is blue <img>',
227                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
228                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
229         ]:
230
231             results = autoregression(
232                 model,
233                 self.batch_size,
234                 nb_samples = nb_per_primer,
235                 nb_tokens_to_generate = nb_tokens_to_generate,
236                 primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
237                 device = self.device
238             )
239
240             l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
241             result_descr += l
242
243         np = picoclvr.nb_properties(
244             result_descr,
245             height = self.height, width = self.width
246         )
247
248         nb_requested_properties, _, nb_missing_properties = zip(*np)
249
250         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
251
252         img = [
253             picoclvr.descr2img(d, height = self.height, width = self.width)
254             for d in result_descr
255         ]
256
257         img = torch.cat(img, 0)
258         image_name = f'result_picoclvr_{n_epoch:04d}.png'
259         torchvision.utils.save_image(
260             img / 255.,
261             image_name, nrow = nb_per_primer, pad_value = 0.8
262         )
263         log_string(f'wrote {image_name}')
264
265 ######################################################################
266
267 class TaskWiki103(Task):
268
269     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
270                  device = torch.device('cpu')):
271
272         self.batch_size = batch_size
273         self.len_min = len_min
274         self.len_max = len_max
275         self.min_freq = min_freq
276         self.device = device
277
278         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
279         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
280
281         # Mostly for debug
282         if args.data_size > 0:
283             train_iter = itertools.islice(train_iter, args.data_size)
284
285         def yield_tokens():
286             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
287                 yield self.tokenizer(l)
288
289         self.vocab = torchtext.vocab.build_vocab_from_iterator(
290             yield_tokens(),
291             specials = [ '<unk>', '<nul>' ],
292             min_freq = self.min_freq
293         )
294
295         self.vocab.set_default_index(self.vocab[ '<unk>' ])
296
297     # makes a tensor from a list of list of tokens
298     def tensorize(self, s):
299         a = max(len(x) for x in s)
300         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
301
302     def yield_batches(self, ds):
303         s = [ ]
304         for l in ds:
305             q = self.tokenizer(l)
306             if len(q) >= self.len_min and len(q) <= self.len_max:
307                 s += [ q ]
308                 if len(s) == self.batch_size:
309                     yield self.tensorize(s)
310                     s = [ ]
311
312         if len(s) > 0:
313             yield self.tensorize(s)
314
315     def batches(self, split = 'train'):
316         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
317
318         # Mostly for debug
319         if args.data_size > 0:
320             data_iter = itertools.islice(data_iter, args.data_size)
321
322         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
323
324     def vocabulary_size(self):
325         return len(self.vocab)
326
327     def produce_results(self, n_epoch, model):
328         nb_tokens = 50
329         file_name = f'result_wiki103_{n_epoch:04d}.txt'
330
331         with open(file_name, 'w') as outfile:
332              for primer in [
333                      'the cat is hunting a',
334                      'paris is the capital',
335                      'cars are convenient',
336                      'the difference between men and women is',
337                      'the object was blue all over and green all over it was',
338                      'cherries are red and lemons are',
339                      'cherries are sweet and lemons are',
340                      'two plus three equals',
341                      'deep learning is',
342              ]:
343                  t_primer = self.tokenizer(primer)
344                  t_generated = [ ]
345
346                  for j in range(nb_tokens):
347
348                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
349                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
350                      output = model(input)
351                      logits = output[0, -1]
352                      if args.synthesis_sampling:
353                          dist = torch.distributions.categorical.Categorical(logits = logits)
354                          t_next = dist.sample()
355                      else:
356                          t_next = logits.argmax()
357                      t_generated.append(self.vocab.lookup_token(t_next))
358                      if t_generated[-1] == '<nul>': break
359
360                  s = ' '.join(t_generated)
361
362                  outfile.write(f'<{primer}> {s}\n')
363
364         log_string(f'wrote {file_name}')
365
366 ######################################################################
367
368 class TaskMNIST(Task):
369
370     def __init__(self, batch_size, device = torch.device('cpu')):
371         self.device = device
372         self.batch_size = batch_size
373
374     def batches(self, split = 'train'):
375         assert split in { 'train', 'test' }
376         data_set = torchvision.datasets.MNIST(
377             root = './data', train = (split == 'train'),
378             download = True
379         )
380         data_input = data_set.data.view(-1, 28 * 28).long()
381         if args.data_size >= 0:
382             data_input = data_input[:args.data_size]
383         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
384             yield batch
385
386     def vocabulary_size(self):
387         return 256
388
389     def produce_results(self, n_epoch, model):
390         nb_samples = 64
391         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
392         image_name = f'result_mnist_{n_epoch:04d}.png'
393         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
394                                      image_name, nrow = 16, pad_value = 0.8)
395         log_string(f'wrote {image_name}')
396
397 ######################################################################
398
399 log_string(f'device {device}')
400
401 if args.data == 'wiki103':
402     nb_epochs_default = 10
403     task = TaskWiki103(batch_size = args.batch_size, device = device)
404 elif args.data == 'mnist':
405     nb_epochs_default = 25
406     task = TaskMNIST(batch_size = args.batch_size, device = device)
407 elif args.data == 'picoclvr':
408     nb_epochs_default = 10
409     task = TaskPicoCLVR(batch_size = args.batch_size,
410                         height = args.picoclvr_height,
411                         width = args.picoclvr_width,
412                         nb_colors = args.picoclvr_nb_colors,
413                         device = device)
414 else:
415     raise ValueError(f'Unknown dataset {args.data}.')
416
417 vocabulary_size = task.vocabulary_size()
418
419 log_string(f'vocabulary_size {vocabulary_size}')
420
421 ##############################
422
423 model = mygpt.MyGPT(
424     vocabulary_size = vocabulary_size,
425     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
426     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
427 )
428
429 model.to(device)
430
431 nb_parameters = sum(p.numel() for p in model.parameters())
432 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
433
434 ######################################################################
435
436 nb_epochs_finished = 0
437
438 if args.no_checkpoint:
439     log_string(f'not trying to load checkpoint.')
440
441 else:
442     try:
443         checkpoint = torch.load(args.checkpoint_name)
444         nb_epochs_finished = checkpoint['nb_epochs_finished']
445         model.load_state_dict(checkpoint['model_state'])
446         torch.set_rng_state(checkpoint['rng_state'])
447         if torch.cuda.is_available():
448             torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
449         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
450
451     except FileNotFoundError:
452         log_string('starting from scratch.')
453
454     except:
455         log_string('error when loading the checkpoint.')
456         exit(1)
457
458 ######################################################################
459
460 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
461
462 token_count = 0
463 for input in task.batches(split = 'train'):
464     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
465 token_probas = token_count / token_count.sum()
466 entropy = -torch.xlogy(token_probas, token_probas).sum()
467 train_set_perplexity = math.exp(entropy)
468
469 for n_epoch in range(nb_epochs_finished, nb_epochs):
470
471     if args.learning_rate_end < 0:
472         lr = args.learning_rate
473     else:
474         u = n_epoch / (nb_epochs - 1)
475         lr = math.exp((1 - u) * math.log(args.learning_rate) +
476                       u * math.log(args.learning_rate_end))
477         log_string(f'learning_rate {lr}')
478
479     if args.optim == 'sgd':
480         optimizer = torch.optim.SGD(model.parameters(), lr = lr)
481     elif args.optim == 'adam':
482         optimizer = torch.optim.Adam(model.parameters(), lr = lr)
483     elif args.optim == 'adamw':
484         optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
485     else:
486         raise ValueError(f'Unknown optimizer {args.optim}.')
487
488     model.train()
489
490     nb_train_samples, acc_train_loss = 0, 0.0
491
492     for input in task.batches(split = 'train'):
493         input = input.to(device)
494         output = model(input)
495         loss = F.cross_entropy(output.transpose(1, 2), input)
496         acc_train_loss += loss.item() * input.size(0)
497         nb_train_samples += input.size(0)
498
499         optimizer.zero_grad()
500         loss.backward()
501         optimizer.step()
502
503     with torch.autograd.no_grad():
504
505         model.eval()
506
507         nb_test_samples, acc_test_loss = 0, 0.0
508
509         for input in task.batches(split = 'test'):
510             input = input.to(device)
511             output = model(input)
512             loss = F.cross_entropy(output.transpose(1, 2), input)
513             acc_test_loss += loss.item() * input.size(0)
514             nb_test_samples += input.size(0)
515
516         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
517         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
518
519         log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
520
521         task.produce_results(n_epoch, model)
522
523     checkpoint = {
524         'nb_epochs_finished': n_epoch + 1,
525         'model_state': model.state_dict(),
526         'rng_state': torch.get_rng_state(),
527     }
528
529     if torch.cuda.is_available():
530         checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
531
532     torch.save(checkpoint, args.checkpoint_name)
533
534 ######################################################################