aa1b51799a159b2f176154d63cce5893eabd38ad
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-3)
46
47 parser.add_argument('--learning_rate_end',
48                     type = float, default = 1e-6)
49
50 parser.add_argument('--dim_model',
51                     type = int, default = 512)
52
53 parser.add_argument('--dim_keys',
54                     type = int, default = 64)
55
56 parser.add_argument('--dim_hidden',
57                     type = int, default = 2048)
58
59 parser.add_argument('--nb_heads',
60                     type = int, default = 8)
61
62 parser.add_argument('--nb_blocks',
63                     type = int, default = 12)
64
65 parser.add_argument('--dropout',
66                     type = float, default = 0.1)
67
68 parser.add_argument('--deterministic_synthesis',
69                     action='store_true', default = False)
70
71 parser.add_argument('--no_checkpoint',
72                     action='store_true', default = False)
73
74 parser.add_argument('--checkpoint_name',
75                     type = str, default = 'checkpoint.pth')
76
77 ##############################
78 # picoclvr options
79
80 parser.add_argument('--picoclvr_nb_colors',
81                     type = int, default = 5)
82
83 parser.add_argument('--picoclvr_height',
84                     type = int, default = 12)
85
86 parser.add_argument('--picoclvr_width',
87                     type = int, default = 16)
88
89 ######################################################################
90
91 args = parser.parse_args()
92
93 log_file = open(args.log_filename, 'w')
94
95 if args.seed >= 0:
96     torch.manual_seed(args.seed)
97
98 ######################################################################
99
100 def log_string(s):
101     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
102
103     if log_file is not None:
104         log_file.write(t + s + '\n')
105         log_file.flush()
106
107     print(t + s)
108     sys.stdout.flush()
109
110 for n in vars(args):
111     log_string(f'args.{n} {getattr(args, n)}')
112
113 ######################################################################
114
115 def autoregression(
116         model, batch_size,
117         nb_samples, nb_tokens_to_generate, primer = None,
118         device = torch.device('cpu')
119 ):
120     results = torch.zeros(
121         nb_samples, nb_tokens_to_generate,
122         dtype = torch.int64, device = device
123     )
124
125     if primer is None:
126         first = 0
127     else:
128         first = primer.size(1)
129         results = torch.cat((primer, results), 1)
130
131     for input in results.split(batch_size):
132         for s in range(first, input.size(1)):
133             output = model(input)
134             logits = output[:, s]
135             if args.deterministic_synthesis:
136                 t_next = logits.argmax(1)
137             else:
138                 dist = torch.distributions.categorical.Categorical(logits = logits)
139                 t_next = dist.sample()
140             input[:, s] = t_next
141
142     return results
143
144 ######################################################################
145
146 class Task:
147     def batches(self, split = 'train'):
148         pass
149
150     def vocabulary_size(self):
151         pass
152
153     def produce_results(self, n_epoch, model):
154         pass
155
156 ######################################################################
157
158 import picoclvr
159
160 class TaskPicoCLVR(Task):
161
162     # Make a tensor from a list of strings
163     def tensorize(self, descr):
164         token_descr = [ s.strip().split(' ') for s in descr ]
165         l = max([ len(s) for s in token_descr ])
166         #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
167         token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
168         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
169         return torch.tensor(id_descr, device = self.device)
170
171     def trim(self, x, token = '<nul>'):
172         n = self.token2id[token]
173         i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
174         a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
175         return x[:, a:b]
176
177     def __init__(self, batch_size,
178                  height, width, nb_colors = 5,
179                  device = torch.device('cpu')):
180
181         def generate_descr(nb):
182             return picoclvr.generate(
183                 nb,
184                 height = self.height, width = self.width,
185                 nb_colors = nb_colors
186             )
187
188         self.height = height
189         self.width = width
190         self.batch_size = batch_size
191         self.device = device
192         nb = args.data_size if args.data_size > 0 else 250000
193
194         log_string(f'generating {nb} samples (can take some time)')
195         self.train_descr = generate_descr((nb * 4) // 5)
196         self.test_descr = generate_descr((nb * 1) // 5)
197
198         # Build the tokenizer
199         tokens = { '<nul>' }
200         for d in [ self.train_descr, self.test_descr ]:
201             for s in d:
202                 for t in s.strip().split(' '): tokens.add(t)
203         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
204         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
205
206         # Tokenize the train and test sets
207         self.train_input = self.tensorize(self.train_descr)
208         self.test_input = self.tensorize(self.test_descr)
209
210     def batches(self, split = 'train'):
211         assert split in { 'train', 'test' }
212         input = self.train_input if split == 'train' else self.test_input
213         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
214             yield self.trim(batch)
215
216     def vocabulary_size(self):
217         return len(self.token2id)
218
219     def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
220         nb_tokens_to_generate = self.height * self.width + 3
221         result_descr = [ ]
222
223         for primer_descr in primers_descr:
224
225             results = autoregression(
226                 model,
227                 self.batch_size,
228                 nb_samples = nb_per_primer,
229                 nb_tokens_to_generate = nb_tokens_to_generate,
230                 primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
231                 device = self.device
232             )
233
234             l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
235             result_descr += l
236
237         np = picoclvr.nb_properties(
238             result_descr,
239             height = self.height, width = self.width
240         )
241
242         nb_requested_properties, _, nb_missing_properties = zip(*np)
243
244         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
245
246         np=torch.tensor(np)
247         count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
248         for i in range(count.size(0)):
249             for j in range(count.size(1)):
250                 count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
251
252         if generate_images:
253             img = [
254                 picoclvr.descr2img(d, height = self.height, width = self.width)
255                 for d in result_descr
256             ]
257
258             img = torch.cat(img, 0)
259             image_name = f'result_picoclvr_{n_epoch:04d}.png'
260             torchvision.utils.save_image(
261                 img / 255.,
262                 image_name, nrow = nb_per_primer, pad_value = 0.8
263             )
264             log_string(f'wrote {image_name}')
265
266         return count
267
268     def produce_results(self, n_epoch, model):
269         primers_descr = [
270             'red above green <sep> green top <sep> blue right of red <img>',
271             'there is red <sep> there is yellow <sep> there is blue <img>',
272             'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
273             'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
274         ]
275
276         self.test_model(
277             n_epoch, model,
278             primers_descr,
279             nb_per_primer=8, generate_images=True
280         )
281
282         # FAR TOO SLOW!!!
283
284         # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
285
286         # count=self.test_model(
287             # n_epoch, model,
288             # test_primers_descr,
289             # nb_per_primer=1, generate_images=False
290         # )
291
292         # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
293             # for i in range(count.size(0)):
294                 # for j in range(count.size(1)):
295                     # f.write(f'{count[i,j]}')
296                     # f.write(" " if j<count.size(1)-1 else "\n")
297
298 ######################################################################
299
300 class TaskWiki103(Task):
301
302     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
303                  device = torch.device('cpu')):
304
305         self.batch_size = batch_size
306         self.len_min = len_min
307         self.len_max = len_max
308         self.min_freq = min_freq
309         self.device = device
310
311         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
312         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
313
314         # Mostly for debug
315         if args.data_size > 0:
316             train_iter = itertools.islice(train_iter, args.data_size)
317
318         def yield_tokens():
319             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
320                 yield self.tokenizer(l)
321
322         self.vocab = torchtext.vocab.build_vocab_from_iterator(
323             yield_tokens(),
324             specials = [ '<unk>', '<nul>' ],
325             min_freq = self.min_freq
326         )
327
328         self.vocab.set_default_index(self.vocab[ '<unk>' ])
329
330     # makes a tensor from a list of list of tokens
331     def tensorize(self, s):
332         a = max(len(x) for x in s)
333         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
334
335     def yield_batches(self, ds):
336         s = [ ]
337         for l in ds:
338             q = self.tokenizer(l)
339             if len(q) >= self.len_min and len(q) <= self.len_max:
340                 s += [ q ]
341                 if len(s) == self.batch_size:
342                     yield self.tensorize(s)
343                     s = [ ]
344
345         if len(s) > 0:
346             yield self.tensorize(s)
347
348     def batches(self, split = 'train'):
349         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
350
351         # Mostly for debug
352         if args.data_size > 0:
353             data_iter = itertools.islice(data_iter, args.data_size)
354
355         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
356
357     def vocabulary_size(self):
358         return len(self.vocab)
359
360     def produce_results(self, n_epoch, model):
361         nb_tokens = 50
362         file_name = f'result_wiki103_{n_epoch:04d}.txt'
363
364         with open(file_name, 'w') as outfile:
365              for primer in [
366                      'the cat is hunting a',
367                      'paris is the capital',
368                      'cars are convenient',
369                      'the difference between men and women is',
370                      'the object was blue all over and green all over it was',
371                      'cherries are red and lemons are',
372                      'cherries are sweet and lemons are',
373                      'two plus three equals',
374                      'deep learning is',
375              ]:
376                  t_primer = self.tokenizer(primer)
377                  t_generated = [ ]
378
379                  for j in range(nb_tokens):
380
381                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
382                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
383                      output = model(input)
384                      logits = output[0, -1]
385                      if args.deterministic_synthesis:
386                          t_next = logits.argmax()
387                      else:
388                          dist = torch.distributions.categorical.Categorical(logits = logits)
389                          t_next = dist.sample()
390                      t_generated.append(self.vocab.lookup_token(t_next))
391                      if t_generated[-1] == '<nul>': break
392
393                  s = ' '.join(t_generated)
394
395                  outfile.write(f'<{primer}> {s}\n')
396
397         log_string(f'wrote {file_name}')
398
399 ######################################################################
400
401 class TaskMNIST(Task):
402
403     def __init__(self, batch_size, device = torch.device('cpu')):
404         self.device = device
405         self.batch_size = batch_size
406
407     def batches(self, split = 'train'):
408         assert split in { 'train', 'test' }
409         data_set = torchvision.datasets.MNIST(
410             root = './data', train = (split == 'train'),
411             download = True
412         )
413         data_input = data_set.data.view(-1, 28 * 28).long()
414         if args.data_size >= 0:
415             data_input = data_input[:args.data_size]
416         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
417             yield batch
418
419     def vocabulary_size(self):
420         return 256
421
422     def produce_results(self, n_epoch, model):
423         nb_samples = 64
424         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
425         image_name = f'result_mnist_{n_epoch:04d}.png'
426         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
427                                      image_name, nrow = 16, pad_value = 0.8)
428         log_string(f'wrote {image_name}')
429
430 ######################################################################
431
432 log_string(f'device {device}')
433
434 if args.data == 'wiki103':
435     nb_epochs_default = 10
436     task = TaskWiki103(batch_size = args.batch_size, device = device)
437 elif args.data == 'mnist':
438     nb_epochs_default = 25
439     task = TaskMNIST(batch_size = args.batch_size, device = device)
440 elif args.data == 'picoclvr':
441     nb_epochs_default = 10
442     task = TaskPicoCLVR(batch_size = args.batch_size,
443                         height = args.picoclvr_height,
444                         width = args.picoclvr_width,
445                         nb_colors = args.picoclvr_nb_colors,
446                         device = device)
447 else:
448     raise ValueError(f'Unknown dataset {args.data}.')
449
450 vocabulary_size = task.vocabulary_size()
451
452 log_string(f'vocabulary_size {vocabulary_size}')
453
454 ##############################
455
456 model = mygpt.MyGPT(
457     vocabulary_size = vocabulary_size,
458     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
459     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
460 )
461
462 model.to(device)
463
464 nb_parameters = sum(p.numel() for p in model.parameters())
465 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
466
467 ######################################################################
468
469 nb_epochs_finished = 0
470
471 if args.no_checkpoint:
472     log_string(f'not trying to load checkpoint.')
473
474 else:
475     try:
476         checkpoint = torch.load(args.checkpoint_name)
477         nb_epochs_finished = checkpoint['nb_epochs_finished']
478         model.load_state_dict(checkpoint['model_state'])
479         torch.set_rng_state(checkpoint['rng_state'])
480         if torch.cuda.is_available():
481             torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
482         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
483
484     except FileNotFoundError:
485         log_string('starting from scratch.')
486
487     except:
488         log_string('error when loading the checkpoint.')
489         exit(1)
490
491 ######################################################################
492
493 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
494
495 token_count = 0
496 for input in task.batches(split = 'train'):
497     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
498 token_probas = token_count / token_count.sum()
499 entropy = -torch.xlogy(token_probas, token_probas).sum()
500 train_set_perplexity = math.exp(entropy)
501
502 for n_epoch in range(nb_epochs_finished, nb_epochs):
503
504     if args.learning_rate_end < 0:
505         lr = args.learning_rate
506     else:
507         u = n_epoch / (nb_epochs - 1)
508         lr = math.exp((1 - u) * math.log(args.learning_rate) +
509                       u * math.log(args.learning_rate_end))
510         log_string(f'learning_rate {lr}')
511
512     if args.optim == 'sgd':
513         optimizer = torch.optim.SGD(model.parameters(), lr = lr)
514     elif args.optim == 'adam':
515         optimizer = torch.optim.Adam(model.parameters(), lr = lr)
516     elif args.optim == 'adamw':
517         optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
518     else:
519         raise ValueError(f'Unknown optimizer {args.optim}.')
520
521     model.train()
522
523     nb_train_samples, acc_train_loss = 0, 0.0
524
525     for input in task.batches(split = 'train'):
526         input = input.to(device)
527         output = model(input)
528         loss = F.cross_entropy(output.transpose(1, 2), input)
529         acc_train_loss += loss.item() * input.size(0)
530         nb_train_samples += input.size(0)
531
532         optimizer.zero_grad()
533         loss.backward()
534         optimizer.step()
535
536     with torch.autograd.no_grad():
537
538         model.eval()
539
540         nb_test_samples, acc_test_loss = 0, 0.0
541
542         for input in task.batches(split = 'test'):
543             input = input.to(device)
544             output = model(input)
545             loss = F.cross_entropy(output.transpose(1, 2), input)
546             acc_test_loss += loss.item() * input.size(0)
547             nb_test_samples += input.size(0)
548
549         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
550         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
551
552         log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
553
554         task.produce_results(n_epoch, model)
555
556     checkpoint = {
557         'nb_epochs_finished': n_epoch + 1,
558         'model_state': model.state_dict(),
559         'rng_state': torch.get_rng_state(),
560     }
561
562     if torch.cuda.is_available():
563         checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
564
565     torch.save(checkpoint, args.checkpoint_name)
566
567 ######################################################################