Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--seed',
28                     type = int, default = 0)
29
30 parser.add_argument('--nb_epochs',
31                     type = int, default = -1)
32
33 parser.add_argument('--batch_size',
34                     type = int, default = 25)
35
36 parser.add_argument('--data',
37                     type = str, default = 'wiki103')
38
39 parser.add_argument('--data_size',
40                     type = int, default = -1)
41
42 parser.add_argument('--optim',
43                     type = str, default = 'adam')
44
45 parser.add_argument('--learning_rate',
46                     type = float, default = 1e-4)
47
48 parser.add_argument('--dim_model',
49                     type = int, default = 512)
50
51 parser.add_argument('--dim_keys',
52                     type = int, default = 64)
53
54 parser.add_argument('--dim_hidden',
55                     type = int, default = 2048)
56
57 parser.add_argument('--nb_heads',
58                     type = int, default = 8)
59
60 parser.add_argument('--nb_blocks',
61                     type = int, default = 12)
62
63 parser.add_argument('--dropout',
64                     type = float, default = 0.1)
65
66 parser.add_argument('--synthesis_sampling',
67                     action='store_true', default = True)
68
69 parser.add_argument('--no_checkpoint',
70                     action='store_true', default = False)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_nb_colors',
79                     type = int, default = 5)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 def autoregression(
114         model,
115         nb_samples, nb_tokens_to_generate, starting_input = None,
116         device = torch.device('cpu')
117 ):
118     results = torch.zeros(
119         nb_samples, nb_tokens_to_generate,
120         dtype = torch.int64, device = device
121     )
122
123     if starting_input is None:
124         first = 0
125     else:
126         first = starting_input.size(1)
127         results = torch.cat((starting_input, results), 1)
128
129     for input in results.split(args.batch_size):
130         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
131             output = model(input)
132             logits = output[:, s]
133             if args.synthesis_sampling:
134                 dist = torch.distributions.categorical.Categorical(logits = logits)
135                 t_next = dist.sample()
136             else:
137                 t_next = logits.argmax(1)
138             input[:, s] = t_next
139
140     return results
141
142 ######################################################################
143
144 class Task:
145     def batches(self, split = 'train'):
146         pass
147
148     def vocabulary_size(self):
149         pass
150
151     def produce_results(self, n_epoch, model, nb_tokens = 50):
152         pass
153
154 ######################################################################
155
156 import picoclvr
157
158 class TaskPicoCLVR(Task):
159
160     def __init__(self, batch_size,
161                  height, width, nb_colors = 5,
162                  device = torch.device('cpu')):
163
164         def generate_descr(nb):
165             descr = picoclvr.generate(
166                 nb,
167                 height = self.height, width = self.width,
168                 nb_colors = nb_colors
169             )
170
171             descr = [ s.strip().split(' ') for s in descr ]
172             l = max([ len(s) for s in descr ])
173             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
174
175             return descr
176
177         self.height = height
178         self.width = width
179         self.batch_size = batch_size
180         self.device = device
181         nb = args.data_size if args.data_size > 0 else 250000
182
183         self.train_descr = generate_descr((nb * 4) // 5)
184         self.test_descr = generate_descr((nb * 1) // 5)
185
186         # Build the tokenizer
187         tokens = set()
188         for d in [ self.train_descr, self.test_descr ]:
189             for s in d:
190                 for t in s: tokens.add(t)
191         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
192         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
193
194         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
195         self.train_input = torch.tensor(t, device = self.device)
196         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
197         self.test_input = torch.tensor(t, device = self.device)
198
199     def batches(self, split = 'train'):
200         assert split in { 'train', 'test' }
201         if split == 'train':
202             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
203                 yield batch
204         else:
205             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
206                 yield batch
207
208     def vocabulary_size(self):
209         return len(self.token2id)
210
211     def generate(self, primer, model, nb_tokens):
212         t_primer = primer.strip().split(' ')
213         t_generated = [ ]
214
215         for j in range(nb_tokens):
216             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
217             input = torch.tensor(t, device = self.device)
218             input = F.pad(input, (0, 1)) # Add the next token, the one to predict
219             output = model(input)
220             logits = output[0, -1]
221             if args.synthesis_sampling:
222                 dist = torch.distributions.categorical.Categorical(logits = logits)
223                 t_next = dist.sample()
224             else:
225                 t_next = logits.argmax()
226             t_generated.append(self.id2token[t_next.item()])
227
228         return ' '.join(t_primer + t_generated)
229
230     def produce_results(self, n_epoch, model, nb_tokens = None):
231         if nb_tokens is None:
232             nb_tokens = self.height * self.width + 3
233         descr = [ ]
234         nb_per_primer = 8
235
236         for primer in [
237                 'red above green <sep> green top <sep> blue right of red <img>',
238                 'there is red <sep> there is yellow <sep> there is blue <img>',
239                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
240                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
241         ]:
242
243             for k in range(nb_per_primer):
244                 descr.append(self.generate(primer, model, nb_tokens))
245
246         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
247         img = torch.cat(img, 0)
248         image_name = f'result_picoclvr_{n_epoch:04d}.png'
249         torchvision.utils.save_image(
250             img / 255.,
251             image_name, nrow = nb_per_primer, pad_value = 0.8
252         )
253         log_string(f'wrote {image_name}')
254
255         nb_missing = sum( [
256             x[2] for x in picoclvr.nb_missing_properties(
257                 descr,
258                 height = self.height, width = self.width
259             )
260         ] )
261
262         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
263
264 ######################################################################
265
266 class TaskWiki103(Task):
267
268     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
269                  device = torch.device('cpu')):
270
271         self.batch_size = batch_size
272         self.len_min = len_min
273         self.len_max = len_max
274         self.min_freq = min_freq
275         self.device = device
276
277         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
278         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
279
280         # Mostly for debug
281         if args.data_size > 0:
282             train_iter = itertools.islice(train_iter, args.data_size)
283
284         def yield_tokens():
285             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
286                 yield self.tokenizer(l)
287
288         self.vocab = torchtext.vocab.build_vocab_from_iterator(
289             yield_tokens(),
290             specials = [ '<unk>', '<non>' ],
291             min_freq = self.min_freq
292         )
293
294         self.vocab.set_default_index(self.vocab[ '<unk>' ])
295
296     def tensorize(self, s):
297         a = max(len(x) for x in s)
298         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
299
300     def yield_batches(self, ds):
301         s = [ ]
302         for l in ds:
303             q = self.tokenizer(l)
304             if len(q) >= self.len_min and len(q) <= self.len_max:
305                 s += [ q ]
306                 if len(s) == self.batch_size:
307                     yield self.tensorize(s)
308                     s = [ ]
309
310         if len(s) > 0:
311             yield self.tensorize(s)
312
313     def batches(self, split = 'train'):
314         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
315
316         # Mostly for debug
317         if args.data_size > 0:
318             data_iter = itertools.islice(data_iter, args.data_size)
319
320         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
321
322     def vocabulary_size(self):
323         return len(self.vocab)
324
325     def produce_results(self, n_epoch, model, nb_tokens = 50):
326         file_name = f'result_wiki103_{n_epoch:04d}.txt'
327
328         with open(file_name, 'w') as outfile:
329              for primer in [
330                      'the cat is hunting a',
331                      'paris is the capital',
332                      'cars are convenient',
333                      'the difference between men and women is',
334                      'the object was blue all over and green all over it was',
335                      'cherries are red and lemons are',
336                      'cherries are sweet and lemons are',
337                      'two plus three equals',
338                      'deep learning is',
339              ]:
340                  t_primer = self.tokenizer(primer)
341                  t_generated = [ ]
342
343                  for j in range(nb_tokens):
344
345                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
346                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
347                      output = model(input)
348                      logits = output[0, -1]
349                      if args.synthesis_sampling:
350                          dist = torch.distributions.categorical.Categorical(logits = logits)
351                          t_next = dist.sample()
352                      else:
353                          t_next = logits.argmax()
354                      t_generated.append(self.vocab.lookup_token(t_next))
355                      if t_generated[-1] == '<non>': break
356
357                  s = ' '.join(t_generated)
358
359                  outfile.write(f'<{primer}> {s}\n')
360
361         log_string(f'wrote {file_name}')
362
363 ######################################################################
364
365 class TaskMNIST(Task):
366
367     def __init__(self, batch_size, device = torch.device('cpu')):
368         self.device = device
369         self.batch_size = batch_size
370
371     def batches(self, split = 'train'):
372         assert split in { 'train', 'test' }
373         data_set = torchvision.datasets.MNIST(
374             root = './data', train = (split == 'train'),
375             download = True
376         )
377         data_input = data_set.data.view(-1, 28 * 28).long()
378         if args.data_size >= 0:
379             data_input = data_input[:args.data_size]
380         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
381             yield batch
382
383     def vocabulary_size(self):
384         return 256
385
386     def produce_results(self, n_epoch, model, nb_samples = 64):
387         results = autoregression(model, nb_samples, 28 * 28, device = self.device)
388         image_name = f'result_mnist_{n_epoch:04d}.png'
389         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
390                                      image_name, nrow = 16, pad_value = 0.8)
391         log_string(f'wrote {image_name}')
392
393 ######################################################################
394
395 log_string(f'device {device}')
396
397 if args.data == 'wiki103':
398     nb_epochs_default = 10
399     task = TaskWiki103(batch_size = args.batch_size, device = device)
400 elif args.data == 'mnist':
401     nb_epochs_default = 25
402     task = TaskMNIST(batch_size = args.batch_size, device = device)
403 elif args.data == 'picoclvr':
404     nb_epochs_default = 10
405     task = TaskPicoCLVR(batch_size = args.batch_size,
406                         height = args.picoclvr_height,
407                         width = args.picoclvr_width,
408                         nb_colors = args.picoclvr_nb_colors,
409                         device = device)
410 else:
411     raise ValueError(f'Unknown dataset {args.data}.')
412
413 vocabulary_size = task.vocabulary_size()
414
415 log_string(f'vocabulary_size {vocabulary_size}')
416
417 ##############################
418
419 model = mygpt.MyGPT(
420     vocabulary_size = vocabulary_size,
421     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
422     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
423 )
424
425 model.to(device)
426
427 nb_parameters = sum(p.numel() for p in model.parameters())
428 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
429
430 ######################################################################
431
432 if args.optim == 'sgd':
433     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
434 elif args.optim == 'adam':
435     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adamw':
437     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
438 else:
439     raise ValueError(f'Unknown optimizer {args.optim}.')
440
441 ######################################################################
442
443 nb_epochs_finished = 0
444
445 if args.no_checkpoint:
446     log_string(f'not trying to load checkpoint.')
447
448 else:
449     try:
450         checkpoint = torch.load(args.checkpoint_name, map_location = device)
451         nb_epochs_finished = checkpoint['nb_epochs_finished']
452         model.load_state_dict(checkpoint['model_state'])
453         optimizer.load_state_dict(checkpoint['optimizer_state'])
454         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
455
456     except FileNotFoundError:
457         log_string('starting from scratch.')
458
459     except:
460         log_string('error when loading the checkpoint.')
461         exit(1)
462
463 ######################################################################
464
465 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
466
467 token_count = 0
468 for input in task.batches(split = 'train'):
469     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
470 token_probas = token_count / token_count.sum()
471 h = -torch.xlogy(token_probas, token_probas).sum()
472 train_set_perplexity = math.exp(h)
473 log_string(f'train set perplexity {train_set_perplexity}')
474
475 for k in range(nb_epochs_finished, nb_epochs):
476
477     model.train()
478
479     nb_train_samples, acc_train_loss = 0, 0.0
480
481     for input in task.batches(split = 'train'):
482         input = input.to(device)
483         output = model(input)
484         loss = F.cross_entropy(output.transpose(1, 2), input)
485         acc_train_loss += loss.item() * input.size(0)
486         nb_train_samples += input.size(0)
487
488         optimizer.zero_grad()
489         loss.backward()
490         optimizer.step()
491
492     with torch.autograd.no_grad():
493
494         model.eval()
495
496         nb_test_samples, acc_test_loss = 0, 0.0
497
498         for input in task.batches(split = 'test'):
499             input = input.to(device)
500             output = model(input)
501             loss = F.cross_entropy(output.transpose(1, 2), input)
502             acc_test_loss += loss.item() * input.size(0)
503             nb_test_samples += input.size(0)
504
505         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
506         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
507
508         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
509
510         task.produce_results(k, model)
511
512     checkpoint = {
513         'nb_epochs_finished': k + 1,
514         'model_state': model.state_dict(),
515         'optimizer_state': optimizer.state_dict()
516     }
517
518     torch.save(checkpoint, args.checkpoint_name)
519
520 ######################################################################