Fixed stuff.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in range(first, input.size(1)):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     def tensorize(self, descr):
160         descr = [ s.strip().split(' ') for s in descr ]
161         l = max([ len(s) for s in descr ])
162         #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
163         descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
164         t = [ [ self.token2id[u] for u in s ] for s in descr ]
165         return torch.tensor(t, device = self.device)
166
167     def __init__(self, batch_size,
168                  height, width, nb_colors = 5,
169                  device = torch.device('cpu')):
170
171         def generate_descr(nb):
172             return picoclvr.generate(
173                 nb,
174                 height = self.height, width = self.width,
175                 nb_colors = nb_colors
176             )
177
178         self.height = height
179         self.width = width
180         self.batch_size = batch_size
181         self.device = device
182         nb = args.data_size if args.data_size > 0 else 250000
183
184         self.train_descr = generate_descr((nb * 4) // 5)
185         self.test_descr = generate_descr((nb * 1) // 5)
186
187         # Build the tokenizer
188         tokens = { '<nul>' }
189         for d in [ self.train_descr, self.test_descr ]:
190             for s in d:
191                 for t in s.strip().split(' '): tokens.add(t)
192         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
193         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
194
195         # Tokenize the train and test sets
196         self.train_input = self.tensorize(self.train_descr)
197         self.test_input = self.tensorize(self.test_descr)
198
199     def batches(self, split = 'train'):
200         assert split in { 'train', 'test' }
201         input = self.train_input if split == 'train' else self.test_input
202         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
203             yield batch
204
205     def vocabulary_size(self):
206         return len(self.token2id)
207
208     def produce_results(self, n_epoch, model):
209         nb_tokens = self.height * self.width + 3
210         result_descr = [ ]
211         nb_per_primer = 8
212
213         for primer_descr in [
214                 'red above green <sep> green top <sep> blue right of red <img>',
215                 'there is red <sep> there is yellow <sep> there is blue <img>',
216                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
217                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
218         ]:
219
220             for k in range(nb_per_primer):
221                 results = autoregression(
222                     model, self.batch_size,
223                     nb_samples = 1, nb_tokens_to_generate = nb_tokens,
224                     primer = self.tensorize([ primer_descr ]),
225                     device = self.device
226                 )
227                 r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
228                 result_descr.append(r)
229
230         img = [
231             picoclvr.descr2img(d, height = self.height, width = self.width)
232             for d in result_descr
233         ]
234
235         img = torch.cat(img, 0)
236         image_name = f'result_picoclvr_{n_epoch:04d}.png'
237         torchvision.utils.save_image(
238             img / 255.,
239             image_name, nrow = nb_per_primer, pad_value = 0.8
240         )
241         log_string(f'wrote {image_name}')
242
243         np = picoclvr.nb_properties(
244             result_descr,
245             height = self.height, width = self.width
246         )
247
248         nb_requested_properties, _, nb_missing_properties = zip(*np)
249
250         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
251
252 ######################################################################
253
254 class TaskWiki103(Task):
255
256     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
257                  device = torch.device('cpu')):
258
259         self.batch_size = batch_size
260         self.len_min = len_min
261         self.len_max = len_max
262         self.min_freq = min_freq
263         self.device = device
264
265         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
266         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
267
268         # Mostly for debug
269         if args.data_size > 0:
270             train_iter = itertools.islice(train_iter, args.data_size)
271
272         def yield_tokens():
273             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
274                 yield self.tokenizer(l)
275
276         self.vocab = torchtext.vocab.build_vocab_from_iterator(
277             yield_tokens(),
278             specials = [ '<unk>', '<nul>' ],
279             min_freq = self.min_freq
280         )
281
282         self.vocab.set_default_index(self.vocab[ '<unk>' ])
283
284     def tensorize(self, s):
285         a = max(len(x) for x in s)
286         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
287
288     def yield_batches(self, ds):
289         s = [ ]
290         for l in ds:
291             q = self.tokenizer(l)
292             if len(q) >= self.len_min and len(q) <= self.len_max:
293                 s += [ q ]
294                 if len(s) == self.batch_size:
295                     yield self.tensorize(s)
296                     s = [ ]
297
298         if len(s) > 0:
299             yield self.tensorize(s)
300
301     def batches(self, split = 'train'):
302         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
303
304         # Mostly for debug
305         if args.data_size > 0:
306             data_iter = itertools.islice(data_iter, args.data_size)
307
308         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
309
310     def vocabulary_size(self):
311         return len(self.vocab)
312
313     def produce_results(self, n_epoch, model):
314         nb_tokens = 50
315         file_name = f'result_wiki103_{n_epoch:04d}.txt'
316
317         with open(file_name, 'w') as outfile:
318              for primer in [
319                      'the cat is hunting a',
320                      'paris is the capital',
321                      'cars are convenient',
322                      'the difference between men and women is',
323                      'the object was blue all over and green all over it was',
324                      'cherries are red and lemons are',
325                      'cherries are sweet and lemons are',
326                      'two plus three equals',
327                      'deep learning is',
328              ]:
329                  t_primer = self.tokenizer(primer)
330                  t_generated = [ ]
331
332                  for j in range(nb_tokens):
333
334                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
335                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
336                      output = model(input)
337                      logits = output[0, -1]
338                      if args.synthesis_sampling:
339                          dist = torch.distributions.categorical.Categorical(logits = logits)
340                          t_next = dist.sample()
341                      else:
342                          t_next = logits.argmax()
343                      t_generated.append(self.vocab.lookup_token(t_next))
344                      if t_generated[-1] == '<nul>': break
345
346                  s = ' '.join(t_generated)
347
348                  outfile.write(f'<{primer}> {s}\n')
349
350         log_string(f'wrote {file_name}')
351
352 ######################################################################
353
354 class TaskMNIST(Task):
355
356     def __init__(self, batch_size, device = torch.device('cpu')):
357         self.device = device
358         self.batch_size = batch_size
359
360     def batches(self, split = 'train'):
361         assert split in { 'train', 'test' }
362         data_set = torchvision.datasets.MNIST(
363             root = './data', train = (split == 'train'),
364             download = True
365         )
366         data_input = data_set.data.view(-1, 28 * 28).long()
367         if args.data_size >= 0:
368             data_input = data_input[:args.data_size]
369         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
370             yield batch
371
372     def vocabulary_size(self):
373         return 256
374
375     def produce_results(self, n_epoch, model):
376         nb_samples = 64
377         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
378         image_name = f'result_mnist_{n_epoch:04d}.png'
379         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
380                                      image_name, nrow = 16, pad_value = 0.8)
381         log_string(f'wrote {image_name}')
382
383 ######################################################################
384
385 log_string(f'device {device}')
386
387 if args.data == 'wiki103':
388     nb_epochs_default = 10
389     task = TaskWiki103(batch_size = args.batch_size, device = device)
390 elif args.data == 'mnist':
391     nb_epochs_default = 25
392     task = TaskMNIST(batch_size = args.batch_size, device = device)
393 elif args.data == 'picoclvr':
394     nb_epochs_default = 10
395     task = TaskPicoCLVR(batch_size = args.batch_size,
396                         height = args.picoclvr_height,
397                         width = args.picoclvr_width,
398                         nb_colors = args.picoclvr_nb_colors,
399                         device = device)
400 else:
401     raise ValueError(f'Unknown dataset {args.data}.')
402
403 vocabulary_size = task.vocabulary_size()
404
405 log_string(f'vocabulary_size {vocabulary_size}')
406
407 ##############################
408
409 model = mygpt.MyGPT(
410     vocabulary_size = vocabulary_size,
411     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
412     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
413 )
414
415 model.to(device)
416
417 nb_parameters = sum(p.numel() for p in model.parameters())
418 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
419
420 ######################################################################
421
422 if args.optim == 'sgd':
423     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
424 elif args.optim == 'adam':
425     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
426 elif args.optim == 'adamw':
427     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
428 else:
429     raise ValueError(f'Unknown optimizer {args.optim}.')
430
431 ######################################################################
432
433 nb_epochs_finished = 0
434
435 if args.no_checkpoint:
436     log_string(f'not trying to load checkpoint.')
437
438 else:
439     try:
440         checkpoint = torch.load(args.checkpoint_name, map_location = device)
441         nb_epochs_finished = checkpoint['nb_epochs_finished']
442         model.load_state_dict(checkpoint['model_state'])
443         optimizer.load_state_dict(checkpoint['optimizer_state'])
444         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
445
446     except FileNotFoundError:
447         log_string('starting from scratch.')
448
449     except:
450         log_string('error when loading the checkpoint.')
451         exit(1)
452
453 ######################################################################
454
455 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
456
457 token_count = 0
458 for input in task.batches(split = 'train'):
459     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
460 token_probas = token_count / token_count.sum()
461 entropy = -torch.xlogy(token_probas, token_probas).sum()
462 train_set_perplexity = math.exp(entropy)
463 #log_string(f'train set perplexity {train_set_perplexity}')
464
465 for k in range(nb_epochs_finished, nb_epochs):
466
467     model.train()
468
469     nb_train_samples, acc_train_loss = 0, 0.0
470
471     for input in task.batches(split = 'train'):
472         input = input.to(device)
473         output = model(input)
474         loss = F.cross_entropy(output.transpose(1, 2), input)
475         acc_train_loss += loss.item() * input.size(0)
476         nb_train_samples += input.size(0)
477
478         optimizer.zero_grad()
479         loss.backward()
480         optimizer.step()
481
482     with torch.autograd.no_grad():
483
484         model.eval()
485
486         nb_test_samples, acc_test_loss = 0, 0.0
487
488         for input in task.batches(split = 'test'):
489             input = input.to(device)
490             output = model(input)
491             loss = F.cross_entropy(output.transpose(1, 2), input)
492             acc_test_loss += loss.item() * input.size(0)
493             nb_test_samples += input.size(0)
494
495         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
496         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
497
498         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
499
500         task.produce_results(k, model)
501
502     checkpoint = {
503         'nb_epochs_finished': k + 1,
504         'model_state': model.state_dict(),
505         'optimizer_state': optimizer.state_dict()
506     }
507
508     torch.save(checkpoint, args.checkpoint_name)
509
510 ######################################################################