Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = -1)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_nb_colors',
82                     type = int, default = 5)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 def autoregression(
117         model,
118         nb_samples, nb_tokens_to_generate, starting_input = None,
119         device = torch.device('cpu')
120 ):
121     results = torch.zeros(
122         nb_samples, nb_tokens_to_generate,
123         dtype = torch.int64, device = device
124     )
125
126     if starting_input is None:
127         first = 0
128     else:
129         first = starting_input.size(1)
130         results = torch.cat((starting_input, results), 1)
131
132     for input in results.split(args.batch_size):
133         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
134             output = model(input)
135             logits = output[:, s]
136             if args.synthesis_sampling:
137                 dist = torch.distributions.categorical.Categorical(logits = logits)
138                 t_next = dist.sample()
139             else:
140                 t_next = logits.argmax(1)
141             input[:, s] = t_next
142
143     return results
144
145 ######################################################################
146
147 class Task:
148     def batches(self, split = 'train'):
149         pass
150
151     def vocabulary_size(self):
152         pass
153
154     def produce_results(self, n_epoch, model, nb_tokens = 50):
155         pass
156
157 ######################################################################
158
159 import picoclvr
160
161 class TaskPicoCLVR(Task):
162
163     def __init__(self, batch_size,
164                  height, width, nb_colors = 5,
165                  device = torch.device('cpu')):
166
167         def generate_descr(nb):
168             descr = picoclvr.generate(
169                 nb,
170                 height = self.height, width = self.width,
171                 nb_colors = nb_colors
172             )
173
174             descr = [ s.strip().split(' ') for s in descr ]
175             l = max([ len(s) for s in descr ])
176             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
177
178             return descr
179
180         self.height = height
181         self.width = width
182         self.batch_size = batch_size
183         self.device = device
184         nb = args.data_size if args.data_size > 0 else 250000
185
186         self.train_descr = generate_descr((nb * 4) // 5)
187         self.test_descr = generate_descr((nb * 1) // 5)
188
189         # Build the tokenizer
190         tokens = set()
191         for d in [ self.train_descr, self.test_descr ]:
192             for s in d:
193                 for t in s: tokens.add(t)
194         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
195         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
196
197         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
198         self.train_input = torch.tensor(t, device = self.device)
199         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
200         self.test_input = torch.tensor(t, device = self.device)
201
202     def batches(self, split = 'train'):
203         assert split in { 'train', 'test' }
204         if split == 'train':
205             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
206                 yield batch
207         else:
208             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
209                 yield batch
210
211     def vocabulary_size(self):
212         return len(self.token2id)
213
214     def generate(self, primer, model, nb_tokens):
215         t_primer = primer.strip().split(' ')
216         t_generated = [ ]
217
218         for j in range(nb_tokens):
219             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
220             input = torch.tensor(t, device = self.device)
221             input = F.pad(input, (0, 1)) # Add the next token, the one to predict
222             output = model(input)
223             logits = output[0, -1]
224             if args.synthesis_sampling:
225                 dist = torch.distributions.categorical.Categorical(logits = logits)
226                 t_next = dist.sample()
227             else:
228                 t_next = logits.argmax()
229             t_generated.append(self.id2token[t_next.item()])
230
231         return ' '.join(t_primer + t_generated)
232
233     def produce_results(self, n_epoch, model, nb_tokens = None):
234         if nb_tokens is None:
235             nb_tokens = self.height * self.width + 3
236         descr = [ ]
237         nb_per_primer = 8
238
239         for primer in [
240                 'red above green <sep> green top <sep> blue right of red <img>',
241                 'there is red <sep> there is yellow <sep> there is blue <img>',
242                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
243                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
244         ]:
245
246             for k in range(nb_per_primer):
247                 descr.append(self.generate(primer, model, nb_tokens))
248
249         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
250         img = torch.cat(img, 0)
251         image_name = f'result_picoclvr_{n_epoch:04d}.png'
252         torchvision.utils.save_image(
253             img / 255.,
254             image_name, nrow = nb_per_primer, pad_value = 0.8
255         )
256         log_string(f'wrote {image_name}')
257
258         nb_missing = sum( [
259             x[2] for x in picoclvr.nb_missing_properties(
260                 descr,
261                 height = self.height, width = self.width
262             )
263         ] )
264
265         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
266
267 ######################################################################
268
269 class TaskWiki103(Task):
270
271     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
272                  device = torch.device('cpu')):
273
274         self.batch_size = batch_size
275         self.len_min = len_min
276         self.len_max = len_max
277         self.min_freq = min_freq
278         self.device = device
279
280         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
281         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
282
283         # Mostly for debug
284         if args.data_size > 0:
285             train_iter = itertools.islice(train_iter, args.data_size)
286
287         def yield_tokens():
288             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
289                 yield self.tokenizer(l)
290
291         self.vocab = torchtext.vocab.build_vocab_from_iterator(
292             yield_tokens(),
293             specials = [ '<unk>', '<non>' ],
294             min_freq = self.min_freq
295         )
296
297         self.vocab.set_default_index(self.vocab[ '<unk>' ])
298
299     def tensorize(self, s):
300         a = max(len(x) for x in s)
301         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
302
303     def yield_batches(self, ds):
304         s = [ ]
305         for l in ds:
306             q = self.tokenizer(l)
307             if len(q) >= self.len_min and len(q) <= self.len_max:
308                 s += [ q ]
309                 if len(s) == self.batch_size:
310                     yield self.tensorize(s)
311                     s = [ ]
312
313         if len(s) > 0:
314             yield self.tensorize(s)
315
316     def batches(self, split = 'train'):
317         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
318
319         # Mostly for debug
320         if args.data_size > 0:
321             data_iter = itertools.islice(data_iter, args.data_size)
322
323         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
324
325     def vocabulary_size(self):
326         return len(self.vocab)
327
328     def produce_results(self, n_epoch, model, nb_tokens = 50):
329         file_name = f'result_wiki103_{n_epoch:04d}.txt'
330
331         with open(file_name, 'w') as outfile:
332              for primer in [
333                      'the cat is hunting a',
334                      'paris is the capital',
335                      'cars are convenient',
336                      'the difference between men and women is',
337                      'the object was blue all over and green all over it was',
338                      'cherries are red and lemons are',
339                      'cherries are sweet and lemons are',
340                      'two plus three equals',
341                      'deep learning is',
342              ]:
343                  t_primer = self.tokenizer(primer)
344                  t_generated = [ ]
345
346                  for j in range(nb_tokens):
347
348                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
349                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
350                      output = model(input)
351                      logits = output[0, -1]
352                      if args.synthesis_sampling:
353                          dist = torch.distributions.categorical.Categorical(logits = logits)
354                          t_next = dist.sample()
355                      else:
356                          t_next = logits.argmax()
357                      t_generated.append(self.vocab.lookup_token(t_next))
358                      if t_generated[-1] == '<non>': break
359
360                  s = ' '.join(t_generated)
361
362                  outfile.write(f'<{primer}> {s}\n')
363
364         log_string(f'wrote {file_name}')
365
366 ######################################################################
367
368 class TaskMNIST(Task):
369
370     def __init__(self, batch_size, device = torch.device('cpu')):
371         self.device = device
372         self.batch_size = batch_size
373
374     def batches(self, split = 'train'):
375         assert split in { 'train', 'test' }
376         data_set = torchvision.datasets.MNIST(
377             root = './data', train = (split == 'train'),
378             download = True
379         )
380         data_input = data_set.data.view(-1, 28 * 28).long()
381         if args.data_size >= 0:
382             data_input = data_input[:args.data_size]
383         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
384             yield batch
385
386     def vocabulary_size(self):
387         return 256
388
389     def produce_results(self, n_epoch, model, nb_samples = 64):
390         results = autoregression(model, nb_samples, 28 * 28, device = self.device)
391         image_name = f'result_mnist_{n_epoch:04d}.png'
392         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
393                                      image_name, nrow = 16, pad_value = 0.8)
394         log_string(f'wrote {image_name}')
395
396 ######################################################################
397
398 log_string(f'device {device}')
399
400 if args.data == 'wiki103':
401     nb_epochs_default = 10
402     task = TaskWiki103(batch_size = args.batch_size, device = device)
403 elif args.data == 'mnist':
404     nb_epochs_default = 25
405     task = TaskMNIST(batch_size = args.batch_size, device = device)
406 elif args.data == 'picoclvr':
407     nb_epochs_default = 10
408     task = TaskPicoCLVR(batch_size = args.batch_size,
409                         height = args.picoclvr_height,
410                         width = args.picoclvr_width,
411                         nb_colors = args.picoclvr_nb_colors,
412                         device = device)
413 else:
414     raise ValueError(f'Unknown dataset {args.data}.')
415
416 vocabulary_size = task.vocabulary_size()
417
418 log_string(f'vocabulary_size {vocabulary_size}')
419
420 ##############################
421
422 model = mygpt.MyGPT(
423     vocabulary_size = vocabulary_size,
424     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
425     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
426 )
427
428 model.to(device)
429
430 nb_parameters = sum(p.numel() for p in model.parameters())
431 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
432
433 ######################################################################
434
435 if args.optim == 'sgd':
436     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
437 elif args.optim == 'adam':
438     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
439 elif args.optim == 'adamw':
440     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
441 else:
442     raise ValueError(f'Unknown optimizer {args.optim}.')
443
444 ######################################################################
445
446 nb_epochs_finished = 0
447
448 if args.no_checkpoint:
449     log_string(f'Not trying to load checkpoint.')
450
451 else:
452     try:
453         checkpoint = torch.load(args.checkpoint_name, map_location = device)
454         nb_epochs_finished = checkpoint['nb_epochs_finished']
455         model.load_state_dict(checkpoint['model_state'])
456         optimizer.load_state_dict(checkpoint['optimizer_state'])
457         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
458
459     except FileNotFoundError:
460         log_string('Starting from scratch.')
461
462     except:
463         log_string('Error when loading the checkpoint.')
464         exit(1)
465
466 ######################################################################
467
468 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
469
470 token_count = 0
471 for input in task.batches(split = 'train'):
472     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
473 token_probas = token_count / token_count.sum()
474 h = -torch.xlogy(token_probas, token_probas).sum()
475 train_set_perplexity = math.exp(h)
476 log_string(f'Train set perplexity {train_set_perplexity}')
477
478 for k in range(nb_epochs_finished, nb_epochs):
479
480     model.train()
481
482     nb_train_samples, acc_train_loss = 0, 0.0
483
484     for input in task.batches(split = 'train'):
485         input = input.to(device)
486         output = model(input)
487         loss = F.cross_entropy(output.transpose(1, 2), input)
488         acc_train_loss += loss.item() * input.size(0)
489         nb_train_samples += input.size(0)
490
491         optimizer.zero_grad()
492         loss.backward()
493         optimizer.step()
494
495     with torch.autograd.no_grad():
496
497         model.eval()
498
499         nb_test_samples, acc_test_loss = 0, 0.0
500
501         for input in task.batches(split = 'test'):
502             input = input.to(device)
503             output = model(input)
504             loss = F.cross_entropy(output.transpose(1, 2), input)
505             acc_test_loss += loss.item() * input.size(0)
506             nb_test_samples += input.size(0)
507
508         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
509         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
510
511         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
512
513         task.produce_results(k, model)
514
515     checkpoint = {
516         'nb_epochs_finished': k + 1,
517         'model_state': model.state_dict(),
518         'optimizer_state': optimizer.state_dict()
519     }
520
521     torch.save(checkpoint, args.checkpoint_name)
522
523 ######################################################################