Cosmetics.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--seed',
28                     type = int, default = 0)
29
30 parser.add_argument('--nb_epochs',
31                     type = int, default = -1)
32
33 parser.add_argument('--batch_size',
34                     type = int, default = 25)
35
36 parser.add_argument('--data',
37                     type = str, default = 'wiki103')
38
39 parser.add_argument('--data_size',
40                     type = int, default = -1)
41
42 parser.add_argument('--optim',
43                     type = str, default = 'adam')
44
45 parser.add_argument('--learning_rate',
46                     type = float, default = 1e-4)
47
48 parser.add_argument('--dim_model',
49                     type = int, default = 512)
50
51 parser.add_argument('--dim_keys',
52                     type = int, default = 64)
53
54 parser.add_argument('--dim_hidden',
55                     type = int, default = 2048)
56
57 parser.add_argument('--nb_heads',
58                     type = int, default = 8)
59
60 parser.add_argument('--nb_blocks',
61                     type = int, default = 12)
62
63 parser.add_argument('--dropout',
64                     type = float, default = 0.1)
65
66 parser.add_argument('--synthesis_sampling',
67                     action='store_true', default = True)
68
69 parser.add_argument('--no_checkpoint',
70                     action='store_true', default = False)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_nb_colors',
79                     type = int, default = 5)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 def autoregression(
114         model, batch_size,
115         nb_samples, nb_tokens_to_generate, primer = None,
116         device = torch.device('cpu')
117 ):
118     results = torch.zeros(
119         nb_samples, nb_tokens_to_generate,
120         dtype = torch.int64, device = device
121     )
122
123     if primer is None:
124         first = 0
125     else:
126         first = primer.size(1)
127         results = torch.cat((primer, results), 1)
128
129     for input in results.split(batch_size):
130         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
131             output = model(input)
132             logits = output[:, s]
133             if args.synthesis_sampling:
134                 dist = torch.distributions.categorical.Categorical(logits = logits)
135                 t_next = dist.sample()
136             else:
137                 t_next = logits.argmax(1)
138             input[:, s] = t_next
139
140     return results
141
142 ######################################################################
143
144 class Task:
145     def batches(self, split = 'train'):
146         pass
147
148     def vocabulary_size(self):
149         pass
150
151     def produce_results(self, n_epoch, model, nb_tokens = 50):
152         pass
153
154 ######################################################################
155
156 import picoclvr
157
158 class TaskPicoCLVR(Task):
159
160     def descr2tensor(self, descr):
161         t = [ [ self.token2id[u] for u in s ] for s in descr ]
162         return torch.tensor(t, device = self.device)
163
164     def __init__(self, batch_size,
165                  height, width, nb_colors = 5,
166                  device = torch.device('cpu')):
167
168         def generate_descr(nb):
169             descr = picoclvr.generate(
170                 nb,
171                 height = self.height, width = self.width,
172                 nb_colors = nb_colors
173             )
174
175             descr = [ s.strip().split(' ') for s in descr ]
176             l = max([ len(s) for s in descr ])
177             #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
178             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
179
180             return descr
181
182         self.height = height
183         self.width = width
184         self.batch_size = batch_size
185         self.device = device
186         nb = args.data_size if args.data_size > 0 else 250000
187
188         self.train_descr = generate_descr((nb * 4) // 5)
189         self.test_descr = generate_descr((nb * 1) // 5)
190
191         # Build the tokenizer
192         tokens = set()
193         for d in [ self.train_descr, self.test_descr ]:
194             for s in d:
195                 for t in s: tokens.add(t)
196         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
197         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
198
199         # Tokenize the train and test sets
200         self.train_input = descr2tensor(self.train_descr)
201         self.test_input = descr2tensor(self.test_descr)
202
203     def batches(self, split = 'train'):
204         assert split in { 'train', 'test' }
205         if split == 'train':
206             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
207                 yield batch
208         else:
209             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
210                 yield batch
211
212     def vocabulary_size(self):
213         return len(self.token2id)
214
215     def generate(self, descr_primer, model, nb_tokens):
216         results = autoregression(
217             model, self.batch_size,
218             1, nb_tokens, primer = descr2tensor(descr_primer),
219             device = self.device
220         )
221         return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
222
223     def produce_results(self, n_epoch, model, nb_tokens = None):
224         if nb_tokens is None:
225             nb_tokens = self.height * self.width + 3
226         result_descr = [ ]
227         nb_per_primer = 8
228
229         for descr_primer in [
230                 'red above green <sep> green top <sep> blue right of red <img>',
231                 'there is red <sep> there is yellow <sep> there is blue <img>',
232                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
233                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
234         ]:
235
236             for k in range(nb_per_primer):
237                 result_descr.append(self.generate(descr_primer, model, nb_tokens))
238
239         img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
240                 for d in result_descr ]
241         img = torch.cat(img, 0)
242         image_name = f'result_picoclvr_{n_epoch:04d}.png'
243         torchvision.utils.save_image(
244             img / 255.,
245             image_name, nrow = nb_per_primer, pad_value = 0.8
246         )
247         log_string(f'wrote {image_name}')
248
249         np = picoclvr.nb_properties(
250             result_descr,
251             height = self.height, width = self.width
252         )
253
254         nb_requested_properties, _, nb_missing_properties = zip(*np)
255
256         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
257
258 ######################################################################
259
260 class TaskWiki103(Task):
261
262     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
263                  device = torch.device('cpu')):
264
265         self.batch_size = batch_size
266         self.len_min = len_min
267         self.len_max = len_max
268         self.min_freq = min_freq
269         self.device = device
270
271         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
272         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
273
274         # Mostly for debug
275         if args.data_size > 0:
276             train_iter = itertools.islice(train_iter, args.data_size)
277
278         def yield_tokens():
279             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
280                 yield self.tokenizer(l)
281
282         self.vocab = torchtext.vocab.build_vocab_from_iterator(
283             yield_tokens(),
284             specials = [ '<unk>', '<non>' ],
285             min_freq = self.min_freq
286         )
287
288         self.vocab.set_default_index(self.vocab[ '<unk>' ])
289
290     def tensorize(self, s):
291         a = max(len(x) for x in s)
292         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
293
294     def yield_batches(self, ds):
295         s = [ ]
296         for l in ds:
297             q = self.tokenizer(l)
298             if len(q) >= self.len_min and len(q) <= self.len_max:
299                 s += [ q ]
300                 if len(s) == self.batch_size:
301                     yield self.tensorize(s)
302                     s = [ ]
303
304         if len(s) > 0:
305             yield self.tensorize(s)
306
307     def batches(self, split = 'train'):
308         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
309
310         # Mostly for debug
311         if args.data_size > 0:
312             data_iter = itertools.islice(data_iter, args.data_size)
313
314         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
315
316     def vocabulary_size(self):
317         return len(self.vocab)
318
319     def produce_results(self, n_epoch, model, nb_tokens = 50):
320         file_name = f'result_wiki103_{n_epoch:04d}.txt'
321
322         with open(file_name, 'w') as outfile:
323              for primer in [
324                      'the cat is hunting a',
325                      'paris is the capital',
326                      'cars are convenient',
327                      'the difference between men and women is',
328                      'the object was blue all over and green all over it was',
329                      'cherries are red and lemons are',
330                      'cherries are sweet and lemons are',
331                      'two plus three equals',
332                      'deep learning is',
333              ]:
334                  t_primer = self.tokenizer(primer)
335                  t_generated = [ ]
336
337                  for j in range(nb_tokens):
338
339                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
340                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
341                      output = model(input)
342                      logits = output[0, -1]
343                      if args.synthesis_sampling:
344                          dist = torch.distributions.categorical.Categorical(logits = logits)
345                          t_next = dist.sample()
346                      else:
347                          t_next = logits.argmax()
348                      t_generated.append(self.vocab.lookup_token(t_next))
349                      if t_generated[-1] == '<non>': break
350
351                  s = ' '.join(t_generated)
352
353                  outfile.write(f'<{primer}> {s}\n')
354
355         log_string(f'wrote {file_name}')
356
357 ######################################################################
358
359 class TaskMNIST(Task):
360
361     def __init__(self, batch_size, device = torch.device('cpu')):
362         self.device = device
363         self.batch_size = batch_size
364
365     def batches(self, split = 'train'):
366         assert split in { 'train', 'test' }
367         data_set = torchvision.datasets.MNIST(
368             root = './data', train = (split == 'train'),
369             download = True
370         )
371         data_input = data_set.data.view(-1, 28 * 28).long()
372         if args.data_size >= 0:
373             data_input = data_input[:args.data_size]
374         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
375             yield batch
376
377     def vocabulary_size(self):
378         return 256
379
380     def produce_results(self, n_epoch, model, nb_samples = 64):
381         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
382         image_name = f'result_mnist_{n_epoch:04d}.png'
383         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
384                                      image_name, nrow = 16, pad_value = 0.8)
385         log_string(f'wrote {image_name}')
386
387 ######################################################################
388
389 log_string(f'device {device}')
390
391 if args.data == 'wiki103':
392     nb_epochs_default = 10
393     task = TaskWiki103(batch_size = args.batch_size, device = device)
394 elif args.data == 'mnist':
395     nb_epochs_default = 25
396     task = TaskMNIST(batch_size = args.batch_size, device = device)
397 elif args.data == 'picoclvr':
398     nb_epochs_default = 10
399     task = TaskPicoCLVR(batch_size = args.batch_size,
400                         height = args.picoclvr_height,
401                         width = args.picoclvr_width,
402                         nb_colors = args.picoclvr_nb_colors,
403                         device = device)
404 else:
405     raise ValueError(f'Unknown dataset {args.data}.')
406
407 vocabulary_size = task.vocabulary_size()
408
409 log_string(f'vocabulary_size {vocabulary_size}')
410
411 ##############################
412
413 model = mygpt.MyGPT(
414     vocabulary_size = vocabulary_size,
415     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
416     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
417 )
418
419 model.to(device)
420
421 nb_parameters = sum(p.numel() for p in model.parameters())
422 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
423
424 ######################################################################
425
426 if args.optim == 'sgd':
427     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
428 elif args.optim == 'adam':
429     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
430 elif args.optim == 'adamw':
431     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
432 else:
433     raise ValueError(f'Unknown optimizer {args.optim}.')
434
435 ######################################################################
436
437 nb_epochs_finished = 0
438
439 if args.no_checkpoint:
440     log_string(f'not trying to load checkpoint.')
441
442 else:
443     try:
444         checkpoint = torch.load(args.checkpoint_name, map_location = device)
445         nb_epochs_finished = checkpoint['nb_epochs_finished']
446         model.load_state_dict(checkpoint['model_state'])
447         optimizer.load_state_dict(checkpoint['optimizer_state'])
448         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
449
450     except FileNotFoundError:
451         log_string('starting from scratch.')
452
453     except:
454         log_string('error when loading the checkpoint.')
455         exit(1)
456
457 ######################################################################
458
459 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
460
461 token_count = 0
462 for input in task.batches(split = 'train'):
463     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
464 token_probas = token_count / token_count.sum()
465 entropy = -torch.xlogy(token_probas, token_probas).sum()
466 train_set_perplexity = math.exp(entropy)
467 #log_string(f'train set perplexity {train_set_perplexity}')
468
469 for k in range(nb_epochs_finished, nb_epochs):
470
471     model.train()
472
473     nb_train_samples, acc_train_loss = 0, 0.0
474
475     for input in task.batches(split = 'train'):
476         input = input.to(device)
477         output = model(input)
478         loss = F.cross_entropy(output.transpose(1, 2), input)
479         acc_train_loss += loss.item() * input.size(0)
480         nb_train_samples += input.size(0)
481
482         optimizer.zero_grad()
483         loss.backward()
484         optimizer.step()
485
486     with torch.autograd.no_grad():
487
488         model.eval()
489
490         nb_test_samples, acc_test_loss = 0, 0.0
491
492         for input in task.batches(split = 'test'):
493             input = input.to(device)
494             output = model(input)
495             loss = F.cross_entropy(output.transpose(1, 2), input)
496             acc_test_loss += loss.item() * input.size(0)
497             nb_test_samples += input.size(0)
498
499         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
500         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
501
502         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
503
504         task.produce_results(k, model)
505
506     checkpoint = {
507         'nb_epochs_finished': k + 1,
508         'model_state': model.state_dict(),
509         'optimizer_state': optimizer.state_dict()
510     }
511
512     torch.save(checkpoint, args.checkpoint_name)
513
514 ######################################################################