Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = -1)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--no_checkpoint',
73                     action='store_true', default = False)
74
75 parser.add_argument('--checkpoint_name',
76                     type = str, default = 'checkpoint.pth')
77
78 ##############################
79 # picoclvr options
80
81 parser.add_argument('--picoclvr_many_colors',
82                     action='store_true', default = False)
83
84 parser.add_argument('--picoclvr_height',
85                     type = int, default = 12)
86
87 parser.add_argument('--picoclvr_width',
88                     type = int, default = 16)
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 log_file = open(args.log_filename, 'w')
95
96 if args.seed >= 0:
97     torch.manual_seed(args.seed)
98
99 ######################################################################
100
101 def log_string(s):
102     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
103
104     if log_file is not None:
105         log_file.write(t + s + '\n')
106         log_file.flush()
107
108     print(t + s)
109     sys.stdout.flush()
110
111 for n in vars(args):
112     log_string(f'args.{n} {getattr(args, n)}')
113
114 ######################################################################
115
116 def produce_results(
117         self,
118         model, nb_samples, nb_tokens_to_generate, starting_input = None,
119         device = 'cpu'
120 ):
121     results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
122     for input in results.split(self.batch_size):
123         for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
124             output = model(input)
125             logits = output[:, s]
126             if args.synthesis_sampling:
127                 dist = torch.distributions.categorical.Categorical(logits = logits)
128                 t = dist.sample()
129             else:
130                 t = logits.argmax(1)
131             input[:, s + 1] = t
132
133 ######################################################################
134
135 class Task:
136     def batches(self, split = 'train'):
137         pass
138
139     def vocabulary_size(self):
140         pass
141
142     def produce_results(self, n_epoch, model, nb_tokens = 50):
143         pass
144
145 ######################################################################
146
147 import picoclvr
148
149 class TaskPicoCLVR(Task):
150
151     def __init__(self, batch_size,
152                  height, width, many_colors = False,
153                  device = torch.device('cpu')):
154
155         def generate_descr(nb):
156             descr = picoclvr.generate(
157                 nb,
158                 height = self.height, width = self.width,
159                 many_colors = many_colors
160             )
161
162             descr = [ s.strip().split(' ') for s in descr ]
163             l = max([ len(s) for s in descr ])
164             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
165
166             return descr
167
168         self.height = height
169         self.width = width
170         self.batch_size = batch_size
171         self.device = device
172         nb = args.data_size if args.data_size > 0 else 250000
173
174         self.train_descr = generate_descr((nb * 4) // 5)
175         self.test_descr = generate_descr((nb * 1) // 5)
176
177         # Build the tokenizer
178         tokens = set()
179         for d in [ self.train_descr, self.test_descr ]:
180             for s in d:
181                 for t in s: tokens.add(t)
182         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
183         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
184
185         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
186         self.train_input = torch.tensor(t, device = self.device)
187         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
188         self.test_input = torch.tensor(t, device = self.device)
189
190     def batches(self, split = 'train'):
191         assert split in { 'train', 'test' }
192         if split == 'train':
193             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
194                 yield batch
195         else:
196             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
197                 yield batch
198
199     def vocabulary_size(self):
200         return len(self.token2id)
201
202     def generate(self, primer, model, nb_tokens):
203         t_primer = primer.strip().split(' ')
204         t_generated = [ ]
205
206         for j in range(nb_tokens):
207             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
208             input = torch.tensor(t, device = self.device)
209             input = F.pad(input, (0, 1)) # Add the next token, the one to predict
210             output = model(input)
211             logits = output[0, -1]
212             if args.synthesis_sampling:
213                 dist = torch.distributions.categorical.Categorical(logits = logits)
214                 t = dist.sample()
215             else:
216                 t = logits.argmax()
217             t_generated.append(self.id2token[t.item()])
218
219         return ' '.join(t_primer + t_generated)
220
221     def produce_results(self, n_epoch, model, nb_tokens = None):
222         if nb_tokens is None:
223             nb_tokens = self.height * self.width + 3
224         descr = [ ]
225         nb_per_primer = 8
226
227         for primer in [
228                 'red above green <sep> green top <sep> blue right of red <img>',
229                 'there is red <sep> there is yellow <sep> there is blue <img>',
230                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
231                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
232         ]:
233
234             for k in range(nb_per_primer):
235                 descr.append(self.generate(primer, model, nb_tokens))
236
237         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
238         img = torch.cat(img, 0)
239         image_name = f'result_picoclvr_{n_epoch:04d}.png'
240         torchvision.utils.save_image(
241             img / 255.,
242             image_name, nrow = nb_per_primer, pad_value = 0.8
243         )
244         log_string(f'wrote {image_name}')
245
246         nb_missing = sum( [
247             x[2] for x in picoclvr.nb_missing_properties(
248                 descr,
249                 height = self.height, width = self.width
250             )
251         ] )
252
253         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
254
255 ######################################################################
256
257 class TaskWiki103(Task):
258
259     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
260                  device = torch.device('cpu')):
261
262         self.batch_size = batch_size
263         self.len_min = len_min
264         self.len_max = len_max
265         self.min_freq = min_freq
266         self.device = device
267
268         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
269         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
270
271         # Mostly for debug
272         if args.data_size > 0:
273             train_iter = itertools.islice(train_iter, args.data_size)
274
275         def yield_tokens():
276             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
277                 yield self.tokenizer(l)
278
279         self.vocab = torchtext.vocab.build_vocab_from_iterator(
280             yield_tokens(),
281             specials = [ '<unk>', '<non>' ],
282             min_freq = self.min_freq
283         )
284
285         self.vocab.set_default_index(self.vocab[ '<unk>' ])
286
287     def tensorize(self, s):
288         a = max(len(x) for x in s)
289         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
290
291     def yield_batches(self, ds):
292         s = [ ]
293         for l in ds:
294             q = self.tokenizer(l)
295             if len(q) >= self.len_min and len(q) <= self.len_max:
296                 s += [ q ]
297                 if len(s) == self.batch_size:
298                     yield self.tensorize(s)
299                     s = [ ]
300
301         if len(s) > 0:
302             yield self.tensorize(s)
303
304     def batches(self, split = 'train'):
305         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
306
307         # Mostly for debug
308         if args.data_size > 0:
309             data_iter = itertools.islice(data_iter, args.data_size)
310
311         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
312
313     def vocabulary_size(self):
314         return len(self.vocab)
315
316     def produce_results(self, n_epoch, model, nb_tokens = 50):
317         file_name = f'result_wiki103_{n_epoch:04d}.txt'
318
319         with open(file_name, 'w') as outfile:
320              for primer in [
321                      'the cat is hunting a',
322                      'paris is the capital',
323                      'cars are convenient',
324                      'the difference between men and women is',
325                      'the object was blue all over and green all over it was',
326                      'cherries are red and lemons are',
327                      'cherries are sweet and lemons are',
328                      'two plus three equals',
329                      'deep learning is',
330              ]:
331                  t_primer = self.tokenizer(primer)
332                  t_generated = [ ]
333
334                  for j in range(nb_tokens):
335
336                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
337                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
338                      output = model(input)
339                      logits = output[0, -1]
340                      if args.synthesis_sampling:
341                          dist = torch.distributions.categorical.Categorical(logits = logits)
342                          t = dist.sample()
343                      else:
344                          t = logits.argmax()
345                      t_generated.append(self.vocab.lookup_token(t))
346                      if t_generated[-1] == '<non>': break
347
348                  s = ' '.join(t_generated)
349
350                  outfile.write(f'<{primer}> {s}\n')
351
352         log_string(f'wrote {file_name}')
353
354 ######################################################################
355
356 class TaskMNIST(Task):
357
358     def __init__(self, batch_size, device = torch.device('cpu')):
359         self.device = device
360         self.batch_size = batch_size
361
362     def batches(self, split = 'train'):
363         assert split in { 'train', 'test' }
364         data_set = torchvision.datasets.MNIST(
365             root = './data', train = (split == 'train'),
366             download = True
367         )
368         data_input = data_set.data.view(-1, 28 * 28).long()
369         if args.data_size >= 0:
370             data_input = data_input[:args.data_size]
371         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
372             yield batch
373
374     def vocabulary_size(self):
375         return 256
376
377     def produce_results(self, n_epoch, model, nb_samples = 64):
378         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
379         for input in results.split(self.batch_size):
380             for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
381                 output = model(input)
382                 logits = output[:, s]
383                 if args.synthesis_sampling:
384                     dist = torch.distributions.categorical.Categorical(logits = logits)
385                     t = dist.sample()
386                 else:
387                     t = logits.argmax(1)
388                 input[:, s] = t
389
390         image_name = f'result_mnist_{n_epoch:04d}.png'
391         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
392                                      image_name, nrow = 16, pad_value = 0.8)
393         log_string(f'wrote {image_name}')
394
395 ######################################################################
396
397 log_string(f'device {device}')
398
399 if args.data == 'wiki103':
400     nb_epochs_default = 10
401     task = TaskWiki103(batch_size = args.batch_size, device = device)
402 elif args.data == 'mnist':
403     nb_epochs_default = 25
404     task = TaskMNIST(batch_size = args.batch_size, device = device)
405 elif args.data == 'picoclvr':
406     nb_epochs_default = 10
407     task = TaskPicoCLVR(batch_size = args.batch_size,
408                         height = args.picoclvr_height,
409                         width = args.picoclvr_width,
410                         many_colors = args.picoclvr_many_colors,
411                         device = device)
412 else:
413     raise ValueError(f'Unknown dataset {args.data}.')
414
415 vocabulary_size = task.vocabulary_size()
416
417 log_string(f'vocabulary_size {vocabulary_size}')
418
419 ##############################
420
421 model = mygpt.MyGPT(
422     vocabulary_size = vocabulary_size,
423     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
424     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
425 )
426
427 model.to(device)
428
429 nb_parameters = sum(p.numel() for p in model.parameters())
430 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
431
432 ######################################################################
433
434 if args.optim == 'sgd':
435     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adam':
437     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
438 elif args.optim == 'adamw':
439     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
440 else:
441     raise ValueError(f'Unknown optimizer {args.optim}.')
442
443 ######################################################################
444
445 nb_epochs_finished = 0
446
447 if args.no_checkpoint:
448     log_string(f'Not trying to load checkpoint.')
449
450 else:
451     try:
452         checkpoint = torch.load(args.checkpoint_name, map_location = device)
453         nb_epochs_finished = checkpoint['nb_epochs_finished']
454         model.load_state_dict(checkpoint['model_state'])
455         optimizer.load_state_dict(checkpoint['optimizer_state'])
456         log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
457
458     except FileNotFoundError:
459         log_string('Starting from scratch.')
460
461     except:
462         log_string('Error when loading the checkpoint.')
463         exit(1)
464
465 ######################################################################
466
467 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
468
469 token_count = 0
470 for input in task.batches(split = 'train'):
471     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
472 token_probas = token_count / token_count.sum()
473 h = -torch.xlogy(token_probas, token_probas).sum()
474 train_set_perplexity = math.exp(h)
475 log_string(f'Train set perplexity {train_set_perplexity}')
476
477 for k in range(nb_epochs_finished, nb_epochs):
478
479     model.train()
480
481     nb_train_samples, acc_train_loss = 0, 0.0
482
483     for input in task.batches(split = 'train'):
484         input = input.to(device)
485         output = model(input)
486         loss = F.cross_entropy(output.transpose(1, 2), input)
487         acc_train_loss += loss.item() * input.size(0)
488         nb_train_samples += input.size(0)
489
490         optimizer.zero_grad()
491         loss.backward()
492         optimizer.step()
493
494     with torch.autograd.no_grad():
495
496         model.eval()
497
498         nb_test_samples, acc_test_loss = 0, 0.0
499
500         for input in task.batches(split = 'test'):
501             input = input.to(device)
502             output = model(input)
503             loss = F.cross_entropy(output.transpose(1, 2), input)
504             acc_test_loss += loss.item() * input.size(0)
505             nb_test_samples += input.size(0)
506
507         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
508         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
509
510         log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
511
512         task.produce_results(k, model)
513
514     checkpoint = {
515         'nb_epochs_finished': k + 1,
516         'model_state': model.state_dict(),
517         'optimizer_state': optimizer.state_dict()
518     }
519
520     torch.save(checkpoint, args.checkpoint_name)
521
522 ######################################################################