b2adf986cb57c9bd447571be06d62379ad97d756
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--seed',
28                     type = int, default = 0)
29
30 parser.add_argument('--nb_epochs',
31                     type = int, default = -1)
32
33 parser.add_argument('--batch_size',
34                     type = int, default = 25)
35
36 parser.add_argument('--data',
37                     type = str, default = 'wiki103')
38
39 parser.add_argument('--data_size',
40                     type = int, default = -1)
41
42 parser.add_argument('--optim',
43                     type = str, default = 'adam')
44
45 parser.add_argument('--learning_rate',
46                     type = float, default = 1e-4)
47
48 parser.add_argument('--dim_model',
49                     type = int, default = 512)
50
51 parser.add_argument('--dim_keys',
52                     type = int, default = 64)
53
54 parser.add_argument('--dim_hidden',
55                     type = int, default = 2048)
56
57 parser.add_argument('--nb_heads',
58                     type = int, default = 8)
59
60 parser.add_argument('--nb_blocks',
61                     type = int, default = 12)
62
63 parser.add_argument('--dropout',
64                     type = float, default = 0.1)
65
66 parser.add_argument('--synthesis_sampling',
67                     action='store_true', default = True)
68
69 parser.add_argument('--no_checkpoint',
70                     action='store_true', default = False)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_nb_colors',
79                     type = int, default = 5)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 def autoregression(
114         model, batch_size,
115         nb_samples, nb_tokens_to_generate, starting_input = None,
116         device = torch.device('cpu')
117 ):
118     results = torch.zeros(
119         nb_samples, nb_tokens_to_generate,
120         dtype = torch.int64, device = device
121     )
122
123     if starting_input is None:
124         first = 0
125     else:
126         first = starting_input.size(1)
127         results = torch.cat((starting_input, results), 1)
128
129     for input in results.split(batch_size):
130         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
131             output = model(input)
132             logits = output[:, s]
133             if args.synthesis_sampling:
134                 dist = torch.distributions.categorical.Categorical(logits = logits)
135                 t_next = dist.sample()
136             else:
137                 t_next = logits.argmax(1)
138             input[:, s] = t_next
139
140     return results
141
142 ######################################################################
143
144 class Task:
145     def batches(self, split = 'train'):
146         pass
147
148     def vocabulary_size(self):
149         pass
150
151     def produce_results(self, n_epoch, model, nb_tokens = 50):
152         pass
153
154 ######################################################################
155
156 import picoclvr
157
158 class TaskPicoCLVR(Task):
159
160     def __init__(self, batch_size,
161                  height, width, nb_colors = 5,
162                  device = torch.device('cpu')):
163
164         def generate_descr(nb):
165             descr = picoclvr.generate(
166                 nb,
167                 height = self.height, width = self.width,
168                 nb_colors = nb_colors
169             )
170
171             descr = [ s.strip().split(' ') for s in descr ]
172             l = max([ len(s) for s in descr ])
173             #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
174             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
175
176             return descr
177
178         self.height = height
179         self.width = width
180         self.batch_size = batch_size
181         self.device = device
182         nb = args.data_size if args.data_size > 0 else 250000
183
184         self.train_descr = generate_descr((nb * 4) // 5)
185         self.test_descr = generate_descr((nb * 1) // 5)
186
187         # Build the tokenizer
188         tokens = set()
189         for d in [ self.train_descr, self.test_descr ]:
190             for s in d:
191                 for t in s: tokens.add(t)
192         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
193         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
194
195         # Tokenize the train and test sets
196         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
197         self.train_input = torch.tensor(t, device = self.device)
198         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
199         self.test_input = torch.tensor(t, device = self.device)
200
201     def batches(self, split = 'train'):
202         assert split in { 'train', 'test' }
203         if split == 'train':
204             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
205                 yield batch
206         else:
207             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
208                 yield batch
209
210     def vocabulary_size(self):
211         return len(self.token2id)
212
213     def generate(self, primer, model, nb_tokens):
214         t_primer = primer.strip().split(' ')
215         t_generated = [ ]
216
217         for j in range(nb_tokens):
218             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
219             input = torch.tensor(t, device = self.device)
220             input = F.pad(input, (0, 1)) # Add the next token, the one to predict
221             output = model(input)
222             logits = output[0, -1]
223             if args.synthesis_sampling:
224                 dist = torch.distributions.categorical.Categorical(logits = logits)
225                 t_next = dist.sample()
226             else:
227                 t_next = logits.argmax()
228             t_generated.append(self.id2token[t_next.item()])
229
230         return ' '.join(t_primer + t_generated)
231
232     def produce_results(self, n_epoch, model, nb_tokens = None):
233         if nb_tokens is None:
234             nb_tokens = self.height * self.width + 3
235         descr = [ ]
236         nb_per_primer = 8
237
238         for primer in [
239                 'red above green <sep> green top <sep> blue right of red <img>',
240                 'there is red <sep> there is yellow <sep> there is blue <img>',
241                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
242                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
243         ]:
244
245             for k in range(nb_per_primer):
246                 descr.append(self.generate(primer, model, nb_tokens))
247
248         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
249         img = torch.cat(img, 0)
250         image_name = f'result_picoclvr_{n_epoch:04d}.png'
251         torchvision.utils.save_image(
252             img / 255.,
253             image_name, nrow = nb_per_primer, pad_value = 0.8
254         )
255         log_string(f'wrote {image_name}')
256
257         np = picoclvr.nb_properties(
258             descr,
259             height = self.height, width = self.width
260         )
261
262         nb_requested_properties, _, nb_missing_properties = zip(*np)
263
264         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(descr):.02f}')
265
266 ######################################################################
267
268 class TaskWiki103(Task):
269
270     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
271                  device = torch.device('cpu')):
272
273         self.batch_size = batch_size
274         self.len_min = len_min
275         self.len_max = len_max
276         self.min_freq = min_freq
277         self.device = device
278
279         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
280         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
281
282         # Mostly for debug
283         if args.data_size > 0:
284             train_iter = itertools.islice(train_iter, args.data_size)
285
286         def yield_tokens():
287             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
288                 yield self.tokenizer(l)
289
290         self.vocab = torchtext.vocab.build_vocab_from_iterator(
291             yield_tokens(),
292             specials = [ '<unk>', '<non>' ],
293             min_freq = self.min_freq
294         )
295
296         self.vocab.set_default_index(self.vocab[ '<unk>' ])
297
298     def tensorize(self, s):
299         a = max(len(x) for x in s)
300         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
301
302     def yield_batches(self, ds):
303         s = [ ]
304         for l in ds:
305             q = self.tokenizer(l)
306             if len(q) >= self.len_min and len(q) <= self.len_max:
307                 s += [ q ]
308                 if len(s) == self.batch_size:
309                     yield self.tensorize(s)
310                     s = [ ]
311
312         if len(s) > 0:
313             yield self.tensorize(s)
314
315     def batches(self, split = 'train'):
316         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
317
318         # Mostly for debug
319         if args.data_size > 0:
320             data_iter = itertools.islice(data_iter, args.data_size)
321
322         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
323
324     def vocabulary_size(self):
325         return len(self.vocab)
326
327     def produce_results(self, n_epoch, model, nb_tokens = 50):
328         file_name = f'result_wiki103_{n_epoch:04d}.txt'
329
330         with open(file_name, 'w') as outfile:
331              for primer in [
332                      'the cat is hunting a',
333                      'paris is the capital',
334                      'cars are convenient',
335                      'the difference between men and women is',
336                      'the object was blue all over and green all over it was',
337                      'cherries are red and lemons are',
338                      'cherries are sweet and lemons are',
339                      'two plus three equals',
340                      'deep learning is',
341              ]:
342                  t_primer = self.tokenizer(primer)
343                  t_generated = [ ]
344
345                  for j in range(nb_tokens):
346
347                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
348                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
349                      output = model(input)
350                      logits = output[0, -1]
351                      if args.synthesis_sampling:
352                          dist = torch.distributions.categorical.Categorical(logits = logits)
353                          t_next = dist.sample()
354                      else:
355                          t_next = logits.argmax()
356                      t_generated.append(self.vocab.lookup_token(t_next))
357                      if t_generated[-1] == '<non>': break
358
359                  s = ' '.join(t_generated)
360
361                  outfile.write(f'<{primer}> {s}\n')
362
363         log_string(f'wrote {file_name}')
364
365 ######################################################################
366
367 class TaskMNIST(Task):
368
369     def __init__(self, batch_size, device = torch.device('cpu')):
370         self.device = device
371         self.batch_size = batch_size
372
373     def batches(self, split = 'train'):
374         assert split in { 'train', 'test' }
375         data_set = torchvision.datasets.MNIST(
376             root = './data', train = (split == 'train'),
377             download = True
378         )
379         data_input = data_set.data.view(-1, 28 * 28).long()
380         if args.data_size >= 0:
381             data_input = data_input[:args.data_size]
382         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
383             yield batch
384
385     def vocabulary_size(self):
386         return 256
387
388     def produce_results(self, n_epoch, model, nb_samples = 64):
389         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
390         image_name = f'result_mnist_{n_epoch:04d}.png'
391         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
392                                      image_name, nrow = 16, pad_value = 0.8)
393         log_string(f'wrote {image_name}')
394
395 ######################################################################
396
397 log_string(f'device {device}')
398
399 if args.data == 'wiki103':
400     nb_epochs_default = 10
401     task = TaskWiki103(batch_size = args.batch_size, device = device)
402 elif args.data == 'mnist':
403     nb_epochs_default = 25
404     task = TaskMNIST(batch_size = args.batch_size, device = device)
405 elif args.data == 'picoclvr':
406     nb_epochs_default = 10
407     task = TaskPicoCLVR(batch_size = args.batch_size,
408                         height = args.picoclvr_height,
409                         width = args.picoclvr_width,
410                         nb_colors = args.picoclvr_nb_colors,
411                         device = device)
412 else:
413     raise ValueError(f'Unknown dataset {args.data}.')
414
415 vocabulary_size = task.vocabulary_size()
416
417 log_string(f'vocabulary_size {vocabulary_size}')
418
419 ##############################
420
421 model = mygpt.MyGPT(
422     vocabulary_size = vocabulary_size,
423     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
424     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
425 )
426
427 model.to(device)
428
429 nb_parameters = sum(p.numel() for p in model.parameters())
430 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
431
432 ######################################################################
433
434 if args.optim == 'sgd':
435     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adam':
437     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
438 elif args.optim == 'adamw':
439     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
440 else:
441     raise ValueError(f'Unknown optimizer {args.optim}.')
442
443 ######################################################################
444
445 nb_epochs_finished = 0
446
447 if args.no_checkpoint:
448     log_string(f'not trying to load checkpoint.')
449
450 else:
451     try:
452         checkpoint = torch.load(args.checkpoint_name, map_location = device)
453         nb_epochs_finished = checkpoint['nb_epochs_finished']
454         model.load_state_dict(checkpoint['model_state'])
455         optimizer.load_state_dict(checkpoint['optimizer_state'])
456         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
457
458     except FileNotFoundError:
459         log_string('starting from scratch.')
460
461     except:
462         log_string('error when loading the checkpoint.')
463         exit(1)
464
465 ######################################################################
466
467 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
468
469 token_count = 0
470 for input in task.batches(split = 'train'):
471     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
472 token_probas = token_count / token_count.sum()
473 h = -torch.xlogy(token_probas, token_probas).sum()
474 train_set_perplexity = math.exp(h)
475 log_string(f'train set perplexity {train_set_perplexity}')
476
477 for k in range(nb_epochs_finished, nb_epochs):
478
479     model.train()
480
481     nb_train_samples, acc_train_loss = 0, 0.0
482
483     for input in task.batches(split = 'train'):
484         input = input.to(device)
485         output = model(input)
486         loss = F.cross_entropy(output.transpose(1, 2), input)
487         acc_train_loss += loss.item() * input.size(0)
488         nb_train_samples += input.size(0)
489
490         optimizer.zero_grad()
491         loss.backward()
492         optimizer.step()
493
494     with torch.autograd.no_grad():
495
496         model.eval()
497
498         nb_test_samples, acc_test_loss = 0, 0.0
499
500         for input in task.batches(split = 'test'):
501             input = input.to(device)
502             output = model(input)
503             loss = F.cross_entropy(output.transpose(1, 2), input)
504             acc_test_loss += loss.item() * input.size(0)
505             nb_test_samples += input.size(0)
506
507         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
508         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
509
510         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
511
512         task.produce_results(k, model)
513
514     checkpoint = {
515         'nb_epochs_finished': k + 1,
516         'model_state': model.state_dict(),
517         'optimizer_state': optimizer.state_dict()
518     }
519
520     torch.save(checkpoint, args.checkpoint_name)
521
522 ######################################################################