c810eef06593c7939e0ad15fe690225509cd4150
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = -1)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_many_colors',
82                     action='store_true', default = False)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 def produce_results(
117         self,
118         model, nb_samples, nb_tokens_to_generate, starting_input = None,
119         device = 'cpu'
120 ):
121     results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
122     for input in results.split(self.batch_size):
123         for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
124             output = model(input)
125             logits = output[:, s]
126             if args.synthesis_sampling:
127                 dist = torch.distributions.categorical.Categorical(logits = logits)
128                 t = dist.sample()
129             else:
130                 t = logits.argmax(1)
131             input[:, s + 1] = t
132
133 ######################################################################
134
135 class Task:
136     def batches(self, split = 'train'):
137         pass
138
139     def vocabulary_size(self):
140         pass
141
142     def produce_results(self, n_epoch, model, nb_tokens = 50):
143         pass
144
145 ######################################################################
146
147 import picoclvr
148
149 class TaskPicoCLVR(Task):
150
151     def __init__(self, batch_size,
152                  height, width, many_colors = False,
153                  device = torch.device('cpu')):
154
155         def generate_descr(nb):
156             descr = picoclvr.generate(
157                 nb,
158                 height = self.height, width = self.width,
159                 many_colors = many_colors
160             )
161
162             descr = [ s.strip().split(' ') for s in descr ]
163             l = max([ len(s) for s in descr ])
164             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
165
166             return descr
167
168         self.height = height
169         self.width = width
170         self.batch_size = batch_size
171         self.device = device
172         nb = args.data_size if args.data_size > 0 else 250000
173
174         self.train_descr = generate_descr((nb * 4) // 5)
175         self.test_descr = generate_descr((nb * 1) // 5)
176
177         # Build the tokenizer
178         tokens = set()
179         for d in [ self.train_descr, self.test_descr ]:
180             for s in d:
181                 for t in s: tokens.add(t)
182         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
183         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
184
185         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
186         self.train_input = torch.tensor(t, device = self.device)
187         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
188         self.test_input = torch.tensor(t, device = self.device)
189
190     def batches(self, split = 'train'):
191         assert split in { 'train', 'test' }
192         if split == 'train':
193             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
194                 yield batch
195         else:
196             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
197                 yield batch
198
199     def vocabulary_size(self):
200         return len(self.token2id)
201
202     def generate(self, primer, model, nb_tokens):
203         t_primer = primer.strip().split(' ')
204         t_generated = [ ]
205
206         for j in range(nb_tokens):
207             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
208             input = torch.tensor(t, device = self.device)
209             output = model(input)
210             logits = output[0, -1]
211             if args.synthesis_sampling:
212                 dist = torch.distributions.categorical.Categorical(logits = logits)
213                 t = dist.sample()
214             else:
215                 t = logits.argmax()
216             t_generated.append(self.id2token[t.item()])
217
218         return ' '.join(t_primer + t_generated)
219
220     def produce_results(self, n_epoch, model, nb_tokens = None):
221         if nb_tokens is None:
222             nb_tokens = self.height * self.width + 3
223         descr = [ ]
224         nb_per_primer = 8
225
226         for primer in [
227                 'red above green <sep> green top <sep> blue right of red <img>',
228                 'there is red <sep> there is yellow <sep> there is blue <img>',
229                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
230                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
231         ]:
232
233             for k in range(nb_per_primer):
234                 descr.append(self.generate(primer, model, nb_tokens))
235
236         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
237         img = torch.cat(img, 0)
238         image_name = f'result_picoclvr_{n_epoch:04d}.png'
239         torchvision.utils.save_image(
240             img / 255.,
241             image_name, nrow = nb_per_primer, pad_value = 0.8
242         )
243         log_string(f'wrote {image_name}')
244
245         nb_missing = sum( [
246             x[2] for x in picoclvr.nb_missing_properties(
247                 descr,
248                 height = self.height, width = self.width
249             )
250         ] )
251
252         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
253
254 ######################################################################
255
256 class TaskWiki103(Task):
257
258     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
259                  device = torch.device('cpu')):
260
261         self.batch_size = batch_size
262         self.len_min = len_min
263         self.len_max = len_max
264         self.min_freq = min_freq
265         self.device = device
266
267         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
268         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
269
270         # Mostly for debug
271         if args.data_size > 0:
272             train_iter = itertools.islice(train_iter, args.data_size)
273
274         def yield_tokens():
275             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
276                 yield self.tokenizer(l)
277
278         self.vocab = torchtext.vocab.build_vocab_from_iterator(
279             yield_tokens(),
280             specials = [ '<unk>', '<non>' ],
281             min_freq = self.min_freq
282         )
283
284         self.vocab.set_default_index(self.vocab[ '<unk>' ])
285
286     def tensorize(self, s):
287         a = max(len(x) for x in s)
288         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
289
290     def yield_batches(self, ds):
291         s = [ ]
292         for l in ds:
293             q = self.tokenizer(l)
294             if len(q) >= self.len_min and len(q) <= self.len_max:
295                 s += [ q ]
296                 if len(s) == self.batch_size:
297                     yield self.tensorize(s)
298                     s = [ ]
299
300         if len(s) > 0:
301             yield self.tensorize(s)
302
303     def batches(self, split = 'train'):
304         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
305
306         # Mostly for debug
307         if args.data_size > 0:
308             data_iter = itertools.islice(data_iter, args.data_size)
309
310         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
311
312     def vocabulary_size(self):
313         return len(self.vocab)
314
315     def produce_results(self, n_epoch, model, nb_tokens = 50):
316         file_name = f'result_wiki103_{n_epoch:04d}.txt'
317
318         with open(file_name, 'w') as outfile:
319              for primer in [
320                      'the cat is hunting a',
321                      'paris is the capital',
322                      'cars are convenient',
323                      'the difference between men and women is',
324                      'the object was blue all over and green all over it was',
325                      'cherries are red and lemons are',
326                      'cherries are sweet and lemons are',
327                      'two plus three equals',
328                      'deep learning is',
329              ]:
330                  t_primer = self.tokenizer(primer)
331                  t_generated = [ ]
332
333                  for j in range(nb_tokens):
334
335                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
336                      output = model(input)
337                      logits = output[0, -1]
338                      if args.synthesis_sampling:
339                          dist = torch.distributions.categorical.Categorical(logits = logits)
340                          t = dist.sample()
341                      else:
342                          t = logits.argmax()
343                      t_generated.append(self.vocab.lookup_token(t))
344                      if t_generated[-1] == '<non>': break
345
346                  s = ' '.join(t_generated)
347
348                  outfile.write(f'<{primer}> {s}\n')
349
350         log_string(f'wrote {file_name}')
351
352 ######################################################################
353
354 class TaskMNIST(Task):
355
356     def __init__(self, batch_size, device = torch.device('cpu')):
357         self.device = device
358         self.batch_size = batch_size
359
360     def batches(self, split = 'train'):
361         assert split in { 'train', 'test' }
362         data_set = torchvision.datasets.MNIST(
363             root = './data', train = (split == 'train'),
364             download = True
365         )
366         data_input = data_set.data.view(-1, 28 * 28).long()
367         if args.data_size >= 0:
368             data_input = data_input[:args.data_size]
369         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
370             yield batch
371
372     def vocabulary_size(self):
373         return 256
374
375     def produce_results(self, n_epoch, model, nb_samples = 64):
376         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
377         for input in results.split(self.batch_size):
378             for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
379                 output = model(input)
380                 logits = output[:, s]
381                 if args.synthesis_sampling:
382                     dist = torch.distributions.categorical.Categorical(logits = logits)
383                     t = dist.sample()
384                 else:
385                     t = logits.argmax(1)
386                 input[:, s] = t
387
388         image_name = f'result_mnist_{n_epoch:04d}.png'
389         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
390                                      image_name, nrow = 16, pad_value = 0.8)
391         log_string(f'wrote {image_name}')
392
393 ######################################################################
394
395 log_string(f'device {device}')
396
397 if args.data == 'wiki103':
398     nb_epochs_default = 10
399     task = TaskWiki103(batch_size = args.batch_size, device = device)
400 elif args.data == 'mnist':
401     nb_epochs_default = 25
402     task = TaskMNIST(batch_size = args.batch_size, device = device)
403 elif args.data == 'picoclvr':
404     nb_epochs_default = 10
405     task = TaskPicoCLVR(batch_size = args.batch_size,
406                         height = args.picoclvr_height,
407                         width = args.picoclvr_width,
408                         many_colors = args.picoclvr_many_colors,
409                         device = device)
410 else:
411     raise ValueError(f'Unknown dataset {args.data}.')
412
413 vocabulary_size = task.vocabulary_size()
414
415 log_string(f'vocabulary_size {vocabulary_size}')
416
417 ##############################
418
419 model = mygpt.MyGPT(
420     vocabulary_size = vocabulary_size,
421     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
422     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
423 )
424
425 model.to(device)
426
427 nb_parameters = sum(p.numel() for p in model.parameters())
428 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
429
430 ######################################################################
431
432 if args.optim == 'sgd':
433     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
434 elif args.optim == 'adam':
435     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adamw':
437     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
438 else:
439     raise ValueError(f'Unknown optimizer {args.optim}.')
440
441 ######################################################################
442
443 nb_epochs_finished = 0
444
445 if args.no_checkpoint:
446     log_string(f'Not trying to load checkpoint.')
447
448 else:
449     try:
450         checkpoint = torch.load(args.checkpoint_name, map_location = device)
451         nb_epochs_finished = checkpoint['nb_epochs_finished']
452         model.load_state_dict(checkpoint['model_state'])
453         optimizer.load_state_dict(checkpoint['optimizer_state'])
454         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
455
456     except FileNotFoundError:
457         log_string('Starting from scratch.')
458
459     except:
460         log_string('Error when loading the checkpoint.')
461         exit(1)
462
463 ######################################################################
464
465 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
466
467 token_count = 0
468 for input in task.batches(split = 'train'):
469     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
470 token_probas = token_count / token_count.sum()
471 h = -torch.xlogy(token_probas, token_probas).sum()
472 train_set_perplexity = math.exp(h)
473 log_string(f'Train set perplexity {train_set_perplexity}')
474
475 for k in range(nb_epochs_finished, nb_epochs):
476
477     model.train()
478
479     nb_train_samples, acc_train_loss = 0, 0.0
480
481     for input in task.batches(split = 'train'):
482         input = input.to(device)
483         output = model(input)
484         loss = F.cross_entropy(output.transpose(1, 2), input)
485         acc_train_loss += loss.item() * input.size(0)
486         nb_train_samples += input.size(0)
487
488         optimizer.zero_grad()
489         loss.backward()
490         optimizer.step()
491
492     with torch.autograd.no_grad():
493
494         model.eval()
495
496         nb_test_samples, acc_test_loss = 0, 0.0
497
498         for input in task.batches(split = 'test'):
499             input = input.to(device)
500             output = model(input)
501             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
502             acc_test_loss += loss.item() * input.size(0)
503             nb_test_samples += input.size(0)
504
505         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
506         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
507
508         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
509
510         task.produce_results(k, model)
511
512     checkpoint = {
513         'nb_epochs_finished': k + 1,
514         'model_state': model.state_dict(),
515         'optimizer_state': optimizer.state_dict()
516     }
517
518     torch.save(checkpoint, args.checkpoint_name)
519
520 ######################################################################