OCDC
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in range(first, input.size(1)):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     # Make a tensor from a list of strings
160     def tensorize(self, descr):
161         token_descr = [ s.strip().split(' ') for s in descr ]
162         l = max([ len(s) for s in token_descr ])
163         #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
164         token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
165         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
166         return torch.tensor(id_descr, device = self.device)
167
168     def trim(self, x, token = '<nul>'):
169         n = self.token2id[token]
170         i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
171         a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
172         return x[:, a:b]
173
174     def __init__(self, batch_size,
175                  height, width, nb_colors = 5,
176                  device = torch.device('cpu')):
177
178         def generate_descr(nb):
179             return picoclvr.generate(
180                 nb,
181                 height = self.height, width = self.width,
182                 nb_colors = nb_colors
183             )
184
185         self.height = height
186         self.width = width
187         self.batch_size = batch_size
188         self.device = device
189         nb = args.data_size if args.data_size > 0 else 250000
190
191         self.train_descr = generate_descr((nb * 4) // 5)
192         self.test_descr = generate_descr((nb * 1) // 5)
193
194         # Build the tokenizer
195         tokens = { '<nul>' }
196         for d in [ self.train_descr, self.test_descr ]:
197             for s in d:
198                 for t in s.strip().split(' '): tokens.add(t)
199         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
200         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
201
202         # Tokenize the train and test sets
203         self.train_input = self.tensorize(self.train_descr)
204         self.test_input = self.tensorize(self.test_descr)
205
206     def batches(self, split = 'train'):
207         assert split in { 'train', 'test' }
208         input = self.train_input if split == 'train' else self.test_input
209         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
210             yield self.trim(batch)
211
212     def vocabulary_size(self):
213         return len(self.token2id)
214
215     def produce_results(self, n_epoch, model):
216         nb_tokens_to_generate = self.height * self.width + 3
217         result_descr = [ ]
218         nb_per_primer = 8
219
220         for primer_descr in [
221                 'red above green <sep> green top <sep> blue right of red <img>',
222                 'there is red <sep> there is yellow <sep> there is blue <img>',
223                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
224                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
225         ]:
226
227             results = autoregression(
228                 model,
229                 self.batch_size,
230                 nb_samples = nb_per_primer,
231                 nb_tokens_to_generate = nb_tokens_to_generate,
232                 primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
233                 device = self.device
234             )
235
236             l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
237             result_descr += l
238
239         np = picoclvr.nb_properties(
240             result_descr,
241             height = self.height, width = self.width
242         )
243
244         nb_requested_properties, _, nb_missing_properties = zip(*np)
245
246         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
247
248         img = [
249             picoclvr.descr2img(d, height = self.height, width = self.width)
250             for d in result_descr
251         ]
252
253         img = torch.cat(img, 0)
254         image_name = f'result_picoclvr_{n_epoch:04d}.png'
255         torchvision.utils.save_image(
256             img / 255.,
257             image_name, nrow = nb_per_primer, pad_value = 0.8
258         )
259         log_string(f'wrote {image_name}')
260
261 ######################################################################
262
263 class TaskWiki103(Task):
264
265     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
266                  device = torch.device('cpu')):
267
268         self.batch_size = batch_size
269         self.len_min = len_min
270         self.len_max = len_max
271         self.min_freq = min_freq
272         self.device = device
273
274         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
275         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
276
277         # Mostly for debug
278         if args.data_size > 0:
279             train_iter = itertools.islice(train_iter, args.data_size)
280
281         def yield_tokens():
282             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
283                 yield self.tokenizer(l)
284
285         self.vocab = torchtext.vocab.build_vocab_from_iterator(
286             yield_tokens(),
287             specials = [ '<unk>', '<nul>' ],
288             min_freq = self.min_freq
289         )
290
291         self.vocab.set_default_index(self.vocab[ '<unk>' ])
292
293     # makes a tensor from a list of list of tokens
294     def tensorize(self, s):
295         a = max(len(x) for x in s)
296         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
297
298     def yield_batches(self, ds):
299         s = [ ]
300         for l in ds:
301             q = self.tokenizer(l)
302             if len(q) >= self.len_min and len(q) <= self.len_max:
303                 s += [ q ]
304                 if len(s) == self.batch_size:
305                     yield self.tensorize(s)
306                     s = [ ]
307
308         if len(s) > 0:
309             yield self.tensorize(s)
310
311     def batches(self, split = 'train'):
312         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
313
314         # Mostly for debug
315         if args.data_size > 0:
316             data_iter = itertools.islice(data_iter, args.data_size)
317
318         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
319
320     def vocabulary_size(self):
321         return len(self.vocab)
322
323     def produce_results(self, n_epoch, model):
324         nb_tokens = 50
325         file_name = f'result_wiki103_{n_epoch:04d}.txt'
326
327         with open(file_name, 'w') as outfile:
328              for primer in [
329                      'the cat is hunting a',
330                      'paris is the capital',
331                      'cars are convenient',
332                      'the difference between men and women is',
333                      'the object was blue all over and green all over it was',
334                      'cherries are red and lemons are',
335                      'cherries are sweet and lemons are',
336                      'two plus three equals',
337                      'deep learning is',
338              ]:
339                  t_primer = self.tokenizer(primer)
340                  t_generated = [ ]
341
342                  for j in range(nb_tokens):
343
344                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
345                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
346                      output = model(input)
347                      logits = output[0, -1]
348                      if args.synthesis_sampling:
349                          dist = torch.distributions.categorical.Categorical(logits = logits)
350                          t_next = dist.sample()
351                      else:
352                          t_next = logits.argmax()
353                      t_generated.append(self.vocab.lookup_token(t_next))
354                      if t_generated[-1] == '<nul>': break
355
356                  s = ' '.join(t_generated)
357
358                  outfile.write(f'<{primer}> {s}\n')
359
360         log_string(f'wrote {file_name}')
361
362 ######################################################################
363
364 class TaskMNIST(Task):
365
366     def __init__(self, batch_size, device = torch.device('cpu')):
367         self.device = device
368         self.batch_size = batch_size
369
370     def batches(self, split = 'train'):
371         assert split in { 'train', 'test' }
372         data_set = torchvision.datasets.MNIST(
373             root = './data', train = (split == 'train'),
374             download = True
375         )
376         data_input = data_set.data.view(-1, 28 * 28).long()
377         if args.data_size >= 0:
378             data_input = data_input[:args.data_size]
379         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
380             yield batch
381
382     def vocabulary_size(self):
383         return 256
384
385     def produce_results(self, n_epoch, model):
386         nb_samples = 64
387         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
388         image_name = f'result_mnist_{n_epoch:04d}.png'
389         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
390                                      image_name, nrow = 16, pad_value = 0.8)
391         log_string(f'wrote {image_name}')
392
393 ######################################################################
394
395 log_string(f'device {device}')
396
397 if args.data == 'wiki103':
398     nb_epochs_default = 10
399     task = TaskWiki103(batch_size = args.batch_size, device = device)
400 elif args.data == 'mnist':
401     nb_epochs_default = 25
402     task = TaskMNIST(batch_size = args.batch_size, device = device)
403 elif args.data == 'picoclvr':
404     nb_epochs_default = 10
405     task = TaskPicoCLVR(batch_size = args.batch_size,
406                         height = args.picoclvr_height,
407                         width = args.picoclvr_width,
408                         nb_colors = args.picoclvr_nb_colors,
409                         device = device)
410 else:
411     raise ValueError(f'Unknown dataset {args.data}.')
412
413 vocabulary_size = task.vocabulary_size()
414
415 log_string(f'vocabulary_size {vocabulary_size}')
416
417 ##############################
418
419 model = mygpt.MyGPT(
420     vocabulary_size = vocabulary_size,
421     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
422     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
423 )
424
425 model.to(device)
426
427 nb_parameters = sum(p.numel() for p in model.parameters())
428 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
429
430 ######################################################################
431
432 if args.optim == 'sgd':
433     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
434 elif args.optim == 'adam':
435     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
436 elif args.optim == 'adamw':
437     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
438 else:
439     raise ValueError(f'Unknown optimizer {args.optim}.')
440
441 ######################################################################
442
443 nb_epochs_finished = 0
444
445 if args.no_checkpoint:
446     log_string(f'not trying to load checkpoint.')
447
448 else:
449     try:
450         checkpoint = torch.load(args.checkpoint_name, map_location = device)
451         nb_epochs_finished = checkpoint['nb_epochs_finished']
452         model.load_state_dict(checkpoint['model_state'])
453         optimizer.load_state_dict(checkpoint['optimizer_state'])
454         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
455
456     except FileNotFoundError:
457         log_string('starting from scratch.')
458
459     except:
460         log_string('error when loading the checkpoint.')
461         exit(1)
462
463 ######################################################################
464
465 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
466
467 token_count = 0
468 for input in task.batches(split = 'train'):
469     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
470 token_probas = token_count / token_count.sum()
471 entropy = -torch.xlogy(token_probas, token_probas).sum()
472 train_set_perplexity = math.exp(entropy)
473
474 for k in range(nb_epochs_finished, nb_epochs):
475
476     model.train()
477
478     nb_train_samples, acc_train_loss = 0, 0.0
479
480     for input in task.batches(split = 'train'):
481         input = input.to(device)
482         output = model(input)
483         loss = F.cross_entropy(output.transpose(1, 2), input)
484         acc_train_loss += loss.item() * input.size(0)
485         nb_train_samples += input.size(0)
486
487         optimizer.zero_grad()
488         loss.backward()
489         optimizer.step()
490
491     with torch.autograd.no_grad():
492
493         model.eval()
494
495         nb_test_samples, acc_test_loss = 0, 0.0
496
497         for input in task.batches(split = 'test'):
498             input = input.to(device)
499             output = model(input)
500             loss = F.cross_entropy(output.transpose(1, 2), input)
501             acc_test_loss += loss.item() * input.size(0)
502             nb_test_samples += input.size(0)
503
504         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
505         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
506
507         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
508
509         task.produce_results(k, model)
510
511     checkpoint = {
512         'nb_epochs_finished': k + 1,
513         'model_state': model.state_dict(),
514         'optimizer_state': optimizer.state_dict()
515     }
516
517     torch.save(checkpoint, args.checkpoint_name)
518
519 ######################################################################