7ce80a31728f49acfb5cd026afda6b5feed7890d
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = -1)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_nb_colors',
82                     type = int, default = 5)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 def autoregression(
117         model,
118         nb_samples, nb_tokens_to_generate, starting_input = None,
119         device = torch.device('cpu')
120 ):
121     first = 0
122     results = torch.zeros(
123         nb_samples, nb_tokens_to_generate,
124         dtype = torch.int64, device = device
125     )
126
127     if starting_input is not None:
128         first = starting_input.size(1)
129         results = torch.cat((starting_input, results), 1)
130
131     for input in results.split(self.batch_size):
132         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
133             output = model(input)
134             logits = output[:, s]
135             if args.synthesis_sampling:
136                 dist = torch.distributions.categorical.Categorical(logits = logits)
137                 t_next = dist.sample()
138             else:
139                 t_next = logits.argmax(1)
140             input[:, s] = t_next
141
142     return results
143
144 ######################################################################
145
146 class Task:
147     def batches(self, split = 'train'):
148         pass
149
150     def vocabulary_size(self):
151         pass
152
153     def produce_results(self, n_epoch, model, nb_tokens = 50):
154         pass
155
156 ######################################################################
157
158 import picoclvr
159
160 class TaskPicoCLVR(Task):
161
162     def __init__(self, batch_size,
163                  height, width, nb_colors = 5,
164                  device = torch.device('cpu')):
165
166         def generate_descr(nb):
167             descr = picoclvr.generate(
168                 nb,
169                 height = self.height, width = self.width,
170                 nb_colors = nb_colors
171             )
172
173             descr = [ s.strip().split(' ') for s in descr ]
174             l = max([ len(s) for s in descr ])
175             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
176
177             return descr
178
179         self.height = height
180         self.width = width
181         self.batch_size = batch_size
182         self.device = device
183         nb = args.data_size if args.data_size > 0 else 250000
184
185         self.train_descr = generate_descr((nb * 4) // 5)
186         self.test_descr = generate_descr((nb * 1) // 5)
187
188         # Build the tokenizer
189         tokens = set()
190         for d in [ self.train_descr, self.test_descr ]:
191             for s in d:
192                 for t in s: tokens.add(t)
193         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
194         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
195
196         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
197         self.train_input = torch.tensor(t, device = self.device)
198         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
199         self.test_input = torch.tensor(t, device = self.device)
200
201     def batches(self, split = 'train'):
202         assert split in { 'train', 'test' }
203         if split == 'train':
204             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
205                 yield batch
206         else:
207             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
208                 yield batch
209
210     def vocabulary_size(self):
211         return len(self.token2id)
212
213     def generate(self, primer, model, nb_tokens):
214         t_primer = primer.strip().split(' ')
215         t_generated = [ ]
216
217         for j in range(nb_tokens):
218             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
219             input = torch.tensor(t, device = self.device)
220             input = F.pad(input, (0, 1)) # Add the next token, the one to predict
221             output = model(input)
222             logits = output[0, -1]
223             if args.synthesis_sampling:
224                 dist = torch.distributions.categorical.Categorical(logits = logits)
225                 t_next = dist.sample()
226             else:
227                 t_next = logits.argmax()
228             t_generated.append(self.id2token[t_next.item()])
229
230         return ' '.join(t_primer + t_generated)
231
232     def produce_results(self, n_epoch, model, nb_tokens = None):
233         if nb_tokens is None:
234             nb_tokens = self.height * self.width + 3
235         descr = [ ]
236         nb_per_primer = 8
237
238         for primer in [
239                 'red above green <sep> green top <sep> blue right of red <img>',
240                 'there is red <sep> there is yellow <sep> there is blue <img>',
241                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
242                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
243         ]:
244
245             for k in range(nb_per_primer):
246                 descr.append(self.generate(primer, model, nb_tokens))
247
248         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
249         img = torch.cat(img, 0)
250         image_name = f'result_picoclvr_{n_epoch:04d}.png'
251         torchvision.utils.save_image(
252             img / 255.,
253             image_name, nrow = nb_per_primer, pad_value = 0.8
254         )
255         log_string(f'wrote {image_name}')
256
257         nb_missing = sum( [
258             x[2] for x in picoclvr.nb_missing_properties(
259                 descr,
260                 height = self.height, width = self.width
261             )
262         ] )
263
264         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
265
266 ######################################################################
267
268 class TaskWiki103(Task):
269
270     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
271                  device = torch.device('cpu')):
272
273         self.batch_size = batch_size
274         self.len_min = len_min
275         self.len_max = len_max
276         self.min_freq = min_freq
277         self.device = device
278
279         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
280         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
281
282         # Mostly for debug
283         if args.data_size > 0:
284             train_iter = itertools.islice(train_iter, args.data_size)
285
286         def yield_tokens():
287             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
288                 yield self.tokenizer(l)
289
290         self.vocab = torchtext.vocab.build_vocab_from_iterator(
291             yield_tokens(),
292             specials = [ '<unk>', '<non>' ],
293             min_freq = self.min_freq
294         )
295
296         self.vocab.set_default_index(self.vocab[ '<unk>' ])
297
298     def tensorize(self, s):
299         a = max(len(x) for x in s)
300         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
301
302     def yield_batches(self, ds):
303         s = [ ]
304         for l in ds:
305             q = self.tokenizer(l)
306             if len(q) >= self.len_min and len(q) <= self.len_max:
307                 s += [ q ]
308                 if len(s) == self.batch_size:
309                     yield self.tensorize(s)
310                     s = [ ]
311
312         if len(s) > 0:
313             yield self.tensorize(s)
314
315     def batches(self, split = 'train'):
316         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
317
318         # Mostly for debug
319         if args.data_size > 0:
320             data_iter = itertools.islice(data_iter, args.data_size)
321
322         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
323
324     def vocabulary_size(self):
325         return len(self.vocab)
326
327     def produce_results(self, n_epoch, model, nb_tokens = 50):
328         file_name = f'result_wiki103_{n_epoch:04d}.txt'
329
330         with open(file_name, 'w') as outfile:
331              for primer in [
332                      'the cat is hunting a',
333                      'paris is the capital',
334                      'cars are convenient',
335                      'the difference between men and women is',
336                      'the object was blue all over and green all over it was',
337                      'cherries are red and lemons are',
338                      'cherries are sweet and lemons are',
339                      'two plus three equals',
340                      'deep learning is',
341              ]:
342                  t_primer = self.tokenizer(primer)
343                  t_generated = [ ]
344
345                  for j in range(nb_tokens):
346
347                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
348                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
349                      output = model(input)
350                      logits = output[0, -1]
351                      if args.synthesis_sampling:
352                          dist = torch.distributions.categorical.Categorical(logits = logits)
353                          t_next = dist.sample()
354                      else:
355                          t_next = logits.argmax()
356                      t_generated.append(self.vocab.lookup_token(t_next))
357                      if t_generated[-1] == '<non>': break
358
359                  s = ' '.join(t_generated)
360
361                  outfile.write(f'<{primer}> {s}\n')
362
363         log_string(f'wrote {file_name}')
364
365 ######################################################################
366
367 class TaskMNIST(Task):
368
369     def __init__(self, batch_size, device = torch.device('cpu')):
370         self.device = device
371         self.batch_size = batch_size
372
373     def batches(self, split = 'train'):
374         assert split in { 'train', 'test' }
375         data_set = torchvision.datasets.MNIST(
376             root = './data', train = (split == 'train'),
377             download = True
378         )
379         data_input = data_set.data.view(-1, 28 * 28).long()
380         if args.data_size >= 0:
381             data_input = data_input[:args.data_size]
382         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
383             yield batch
384
385     def vocabulary_size(self):
386         return 256
387
388     def produce_results(self, n_epoch, model, nb_samples = 64):
389         results = autoregression(model, nb_samples, 28 * 28, device = self.device)
390         image_name = f'result_mnist_{n_epoch:04d}.png'
391         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
392                                      image_name, nrow = 16, pad_value = 0.8)
393         log_string(f'wrote {image_name}')
394
395 ######################################################################
396
397 log_string(f'device {device}')
398
399 if args.data == 'wiki103':
400     nb_epochs_default = 10
401     task = TaskWiki103(batch_size = args.batch_size, device = device)
402 elif args.data == 'mnist':
403     nb_epochs_default = 25
404     task = TaskMNIST(batch_size = args.batch_size, device = device)
405 elif args.data == 'picoclvr':
406     nb_epochs_default = 10
407     task = TaskPicoCLVR(batch_size = args.batch_size,
408                         height = args.picoclvr_height,
409                         width = args.picoclvr_width,
410                         nb_colors = args.picoclvr_nb_colors,
411                         device = device)
412 else:
413     raise ValueError(f'Unknown dataset {args.data}.')
414
415 vocabulary_size = task.vocabulary_size()
416
417 log_string(f'vocabulary_size {vocabulary_size}')
418
419 ##############################
420
421 model = mygpt.MyGPT(
422     vocabulary_size = vocabulary_size,
423     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
424     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
425 )
426
427 model.to(device)
428
429 nb_parameters = sum(p.numel() for p in model.parameters())
430 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
431
432 ######################################################################
433
434 if args.optim == 'sgd':
435     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adam':
437     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
438 elif args.optim == 'adamw':
439     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
440 else:
441     raise ValueError(f'Unknown optimizer {args.optim}.')
442
443 ######################################################################
444
445 nb_epochs_finished = 0
446
447 if args.no_checkpoint:
448     log_string(f'Not trying to load checkpoint.')
449
450 else:
451     try:
452         checkpoint = torch.load(args.checkpoint_name, map_location = device)
453         nb_epochs_finished = checkpoint['nb_epochs_finished']
454         model.load_state_dict(checkpoint['model_state'])
455         optimizer.load_state_dict(checkpoint['optimizer_state'])
456         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
457
458     except FileNotFoundError:
459         log_string('Starting from scratch.')
460
461     except:
462         log_string('Error when loading the checkpoint.')
463         exit(1)
464
465 ######################################################################
466
467 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
468
469 token_count = 0
470 for input in task.batches(split = 'train'):
471     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
472 token_probas = token_count / token_count.sum()
473 h = -torch.xlogy(token_probas, token_probas).sum()
474 train_set_perplexity = math.exp(h)
475 log_string(f'Train set perplexity {train_set_perplexity}')
476
477 for k in range(nb_epochs_finished, nb_epochs):
478
479     model.train()
480
481     nb_train_samples, acc_train_loss = 0, 0.0
482
483     for input in task.batches(split = 'train'):
484         input = input.to(device)
485         output = model(input)
486         loss = F.cross_entropy(output.transpose(1, 2), input)
487         acc_train_loss += loss.item() * input.size(0)
488         nb_train_samples += input.size(0)
489
490         optimizer.zero_grad()
491         loss.backward()
492         optimizer.step()
493
494     with torch.autograd.no_grad():
495
496         model.eval()
497
498         nb_test_samples, acc_test_loss = 0, 0.0
499
500         for input in task.batches(split = 'test'):
501             input = input.to(device)
502             output = model(input)
503             loss = F.cross_entropy(output.transpose(1, 2), input)
504             acc_test_loss += loss.item() * input.size(0)
505             nb_test_samples += input.size(0)
506
507         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
508         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
509
510         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
511
512         task.produce_results(k, model)
513
514     checkpoint = {
515         'nb_epochs_finished': k + 1,
516         'model_state': model.state_dict(),
517         'optimizer_state': optimizer.state_dict()
518     }
519
520     torch.save(checkpoint, args.checkpoint_name)
521
522 ######################################################################