Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in range(first, input.size(1)):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     # Make a tensor from a list of strings
160     def tensorize(self, descr):
161         token_descr = [ s.strip().split(' ') for s in descr ]
162         l = max([ len(s) for s in token_descr ])
163         #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
164         token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
165         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
166         return torch.tensor(id_descr, device = self.device)
167
168     def trim(self, x, token = '<nul>'):
169         n = self.token2id[token]
170         i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
171         a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
172         return x[:, a:b]
173
174     def __init__(self, batch_size,
175                  height, width, nb_colors = 5,
176                  device = torch.device('cpu')):
177
178         def generate_descr(nb):
179             return picoclvr.generate(
180                 nb,
181                 height = self.height, width = self.width,
182                 nb_colors = nb_colors
183             )
184
185         self.height = height
186         self.width = width
187         self.batch_size = batch_size
188         self.device = device
189         nb = args.data_size if args.data_size > 0 else 250000
190
191         log_string('generating {nb} samples (can take some time)')
192         self.train_descr = generate_descr((nb * 4) // 5)
193         self.test_descr = generate_descr((nb * 1) // 5)
194
195         # Build the tokenizer
196         tokens = { '<nul>' }
197         for d in [ self.train_descr, self.test_descr ]:
198             for s in d:
199                 for t in s.strip().split(' '): tokens.add(t)
200         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
201         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
202
203         # Tokenize the train and test sets
204         self.train_input = self.tensorize(self.train_descr)
205         self.test_input = self.tensorize(self.test_descr)
206
207     def batches(self, split = 'train'):
208         assert split in { 'train', 'test' }
209         input = self.train_input if split == 'train' else self.test_input
210         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
211             yield self.trim(batch)
212
213     def vocabulary_size(self):
214         return len(self.token2id)
215
216     def produce_results(self, n_epoch, model):
217         nb_tokens_to_generate = self.height * self.width + 3
218         result_descr = [ ]
219         nb_per_primer = 8
220
221         for primer_descr in [
222                 'red above green <sep> green top <sep> blue right of red <img>',
223                 'there is red <sep> there is yellow <sep> there is blue <img>',
224                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
225                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
226         ]:
227
228             results = autoregression(
229                 model,
230                 self.batch_size,
231                 nb_samples = nb_per_primer,
232                 nb_tokens_to_generate = nb_tokens_to_generate,
233                 primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
234                 device = self.device
235             )
236
237             l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
238             result_descr += l
239
240         np = picoclvr.nb_properties(
241             result_descr,
242             height = self.height, width = self.width
243         )
244
245         nb_requested_properties, _, nb_missing_properties = zip(*np)
246
247         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
248
249         img = [
250             picoclvr.descr2img(d, height = self.height, width = self.width)
251             for d in result_descr
252         ]
253
254         img = torch.cat(img, 0)
255         image_name = f'result_picoclvr_{n_epoch:04d}.png'
256         torchvision.utils.save_image(
257             img / 255.,
258             image_name, nrow = nb_per_primer, pad_value = 0.8
259         )
260         log_string(f'wrote {image_name}')
261
262 ######################################################################
263
264 class TaskWiki103(Task):
265
266     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
267                  device = torch.device('cpu')):
268
269         self.batch_size = batch_size
270         self.len_min = len_min
271         self.len_max = len_max
272         self.min_freq = min_freq
273         self.device = device
274
275         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
276         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
277
278         # Mostly for debug
279         if args.data_size > 0:
280             train_iter = itertools.islice(train_iter, args.data_size)
281
282         def yield_tokens():
283             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
284                 yield self.tokenizer(l)
285
286         self.vocab = torchtext.vocab.build_vocab_from_iterator(
287             yield_tokens(),
288             specials = [ '<unk>', '<nul>' ],
289             min_freq = self.min_freq
290         )
291
292         self.vocab.set_default_index(self.vocab[ '<unk>' ])
293
294     # makes a tensor from a list of list of tokens
295     def tensorize(self, s):
296         a = max(len(x) for x in s)
297         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
298
299     def yield_batches(self, ds):
300         s = [ ]
301         for l in ds:
302             q = self.tokenizer(l)
303             if len(q) >= self.len_min and len(q) <= self.len_max:
304                 s += [ q ]
305                 if len(s) == self.batch_size:
306                     yield self.tensorize(s)
307                     s = [ ]
308
309         if len(s) > 0:
310             yield self.tensorize(s)
311
312     def batches(self, split = 'train'):
313         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
314
315         # Mostly for debug
316         if args.data_size > 0:
317             data_iter = itertools.islice(data_iter, args.data_size)
318
319         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
320
321     def vocabulary_size(self):
322         return len(self.vocab)
323
324     def produce_results(self, n_epoch, model):
325         nb_tokens = 50
326         file_name = f'result_wiki103_{n_epoch:04d}.txt'
327
328         with open(file_name, 'w') as outfile:
329              for primer in [
330                      'the cat is hunting a',
331                      'paris is the capital',
332                      'cars are convenient',
333                      'the difference between men and women is',
334                      'the object was blue all over and green all over it was',
335                      'cherries are red and lemons are',
336                      'cherries are sweet and lemons are',
337                      'two plus three equals',
338                      'deep learning is',
339              ]:
340                  t_primer = self.tokenizer(primer)
341                  t_generated = [ ]
342
343                  for j in range(nb_tokens):
344
345                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
346                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
347                      output = model(input)
348                      logits = output[0, -1]
349                      if args.synthesis_sampling:
350                          dist = torch.distributions.categorical.Categorical(logits = logits)
351                          t_next = dist.sample()
352                      else:
353                          t_next = logits.argmax()
354                      t_generated.append(self.vocab.lookup_token(t_next))
355                      if t_generated[-1] == '<nul>': break
356
357                  s = ' '.join(t_generated)
358
359                  outfile.write(f'<{primer}> {s}\n')
360
361         log_string(f'wrote {file_name}')
362
363 ######################################################################
364
365 class TaskMNIST(Task):
366
367     def __init__(self, batch_size, device = torch.device('cpu')):
368         self.device = device
369         self.batch_size = batch_size
370
371     def batches(self, split = 'train'):
372         assert split in { 'train', 'test' }
373         data_set = torchvision.datasets.MNIST(
374             root = './data', train = (split == 'train'),
375             download = True
376         )
377         data_input = data_set.data.view(-1, 28 * 28).long()
378         if args.data_size >= 0:
379             data_input = data_input[:args.data_size]
380         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
381             yield batch
382
383     def vocabulary_size(self):
384         return 256
385
386     def produce_results(self, n_epoch, model):
387         nb_samples = 64
388         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
389         image_name = f'result_mnist_{n_epoch:04d}.png'
390         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
391                                      image_name, nrow = 16, pad_value = 0.8)
392         log_string(f'wrote {image_name}')
393
394 ######################################################################
395
396 log_string(f'device {device}')
397
398 if args.data == 'wiki103':
399     nb_epochs_default = 10
400     task = TaskWiki103(batch_size = args.batch_size, device = device)
401 elif args.data == 'mnist':
402     nb_epochs_default = 25
403     task = TaskMNIST(batch_size = args.batch_size, device = device)
404 elif args.data == 'picoclvr':
405     nb_epochs_default = 10
406     task = TaskPicoCLVR(batch_size = args.batch_size,
407                         height = args.picoclvr_height,
408                         width = args.picoclvr_width,
409                         nb_colors = args.picoclvr_nb_colors,
410                         device = device)
411 else:
412     raise ValueError(f'Unknown dataset {args.data}.')
413
414 vocabulary_size = task.vocabulary_size()
415
416 log_string(f'vocabulary_size {vocabulary_size}')
417
418 ##############################
419
420 model = mygpt.MyGPT(
421     vocabulary_size = vocabulary_size,
422     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
423     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
424 )
425
426 model.to(device)
427
428 nb_parameters = sum(p.numel() for p in model.parameters())
429 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
430
431 ######################################################################
432
433 if args.optim == 'sgd':
434     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
435 elif args.optim == 'adam':
436     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
437 elif args.optim == 'adamw':
438     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
439 else:
440     raise ValueError(f'Unknown optimizer {args.optim}.')
441
442 ######################################################################
443
444 nb_epochs_finished = 0
445
446 if args.no_checkpoint:
447     log_string(f'not trying to load checkpoint.')
448
449 else:
450     try:
451         checkpoint = torch.load(args.checkpoint_name, map_location = device)
452         nb_epochs_finished = checkpoint['nb_epochs_finished']
453         model.load_state_dict(checkpoint['model_state'])
454         optimizer.load_state_dict(checkpoint['optimizer_state'])
455         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
456
457     except FileNotFoundError:
458         log_string('starting from scratch.')
459
460     except:
461         log_string('error when loading the checkpoint.')
462         exit(1)
463
464 ######################################################################
465
466 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
467
468 token_count = 0
469 for input in task.batches(split = 'train'):
470     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
471 token_probas = token_count / token_count.sum()
472 entropy = -torch.xlogy(token_probas, token_probas).sum()
473 train_set_perplexity = math.exp(entropy)
474
475 for k in range(nb_epochs_finished, nb_epochs):
476
477     model.train()
478
479     nb_train_samples, acc_train_loss = 0, 0.0
480
481     for input in task.batches(split = 'train'):
482         input = input.to(device)
483         output = model(input)
484         loss = F.cross_entropy(output.transpose(1, 2), input)
485         acc_train_loss += loss.item() * input.size(0)
486         nb_train_samples += input.size(0)
487
488         optimizer.zero_grad()
489         loss.backward()
490         optimizer.step()
491
492     with torch.autograd.no_grad():
493
494         model.eval()
495
496         nb_test_samples, acc_test_loss = 0, 0.0
497
498         for input in task.batches(split = 'test'):
499             input = input.to(device)
500             output = model(input)
501             loss = F.cross_entropy(output.transpose(1, 2), input)
502             acc_test_loss += loss.item() * input.size(0)
503             nb_test_samples += input.size(0)
504
505         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
506         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
507
508         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
509
510         task.produce_results(k, model)
511
512     checkpoint = {
513         'nb_epochs_finished': k + 1,
514         'model_state': model.state_dict(),
515         'optimizer_state': optimizer.state_dict()
516     }
517
518     torch.save(checkpoint, args.checkpoint_name)
519
520 ######################################################################