OCDC
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     def tensorize(self, descr):
160         t = [ [ self.token2id[u] for u in s ] for s in descr ]
161         return torch.tensor(t, device = self.device)
162
163     def __init__(self, batch_size,
164                  height, width, nb_colors = 5,
165                  device = torch.device('cpu')):
166
167         def generate_descr(nb):
168             descr = picoclvr.generate(
169                 nb,
170                 height = self.height, width = self.width,
171                 nb_colors = nb_colors
172             )
173
174             descr = [ s.strip().split(' ') for s in descr ]
175             l = max([ len(s) for s in descr ])
176             #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
177             descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
178
179             return descr
180
181         self.height = height
182         self.width = width
183         self.batch_size = batch_size
184         self.device = device
185         nb = args.data_size if args.data_size > 0 else 250000
186
187         self.train_descr = generate_descr((nb * 4) // 5)
188         self.test_descr = generate_descr((nb * 1) // 5)
189
190         # Build the tokenizer
191         tokens = set()
192         for d in [ self.train_descr, self.test_descr ]:
193             for s in d:
194                 for t in s: tokens.add(t)
195         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
196         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
197
198         # Tokenize the train and test sets
199         self.train_input = self.tensorize(self.train_descr)
200         self.test_input = self.tensorize(self.test_descr)
201
202     def batches(self, split = 'train'):
203         assert split in { 'train', 'test' }
204         input = self.train_input if split == 'train' else self.test_input
205         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
206             yield batch
207
208     def vocabulary_size(self):
209         return len(self.token2id)
210
211     def produce_results(self, n_epoch, model):
212         nb_tokens = self.height * self.width + 3
213         result_descr = [ ]
214         nb_per_primer = 8
215
216         for primer_descr in [
217                 'red above green <sep> green top <sep> blue right of red <img>',
218                 'there is red <sep> there is yellow <sep> there is blue <img>',
219                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
220                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
221         ]:
222
223             for k in range(nb_per_primer):
224                 results = autoregression(
225                     model, self.batch_size,
226                     nb_samples = 1, nb_tokens = nb_tokens,
227                     primer = self.tensorize(primer_descr),
228                     device = self.device
229                 )
230                 r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
231                 result_descr.append(r)
232
233         img = [
234             picoclvr.descr2img(d, height = self.height, width = self.width)
235             for d in result_descr
236         ]
237
238         img = torch.cat(img, 0)
239         image_name = f'result_picoclvr_{n_epoch:04d}.png'
240         torchvision.utils.save_image(
241             img / 255.,
242             image_name, nrow = nb_per_primer, pad_value = 0.8
243         )
244         log_string(f'wrote {image_name}')
245
246         np = picoclvr.nb_properties(
247             result_descr,
248             height = self.height, width = self.width
249         )
250
251         nb_requested_properties, _, nb_missing_properties = zip(*np)
252
253         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
254
255 ######################################################################
256
257 class TaskWiki103(Task):
258
259     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
260                  device = torch.device('cpu')):
261
262         self.batch_size = batch_size
263         self.len_min = len_min
264         self.len_max = len_max
265         self.min_freq = min_freq
266         self.device = device
267
268         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
269         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
270
271         # Mostly for debug
272         if args.data_size > 0:
273             train_iter = itertools.islice(train_iter, args.data_size)
274
275         def yield_tokens():
276             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
277                 yield self.tokenizer(l)
278
279         self.vocab = torchtext.vocab.build_vocab_from_iterator(
280             yield_tokens(),
281             specials = [ '<unk>', '<nul>' ],
282             min_freq = self.min_freq
283         )
284
285         self.vocab.set_default_index(self.vocab[ '<unk>' ])
286
287     def tensorize(self, s):
288         a = max(len(x) for x in s)
289         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
290
291     def yield_batches(self, ds):
292         s = [ ]
293         for l in ds:
294             q = self.tokenizer(l)
295             if len(q) >= self.len_min and len(q) <= self.len_max:
296                 s += [ q ]
297                 if len(s) == self.batch_size:
298                     yield self.tensorize(s)
299                     s = [ ]
300
301         if len(s) > 0:
302             yield self.tensorize(s)
303
304     def batches(self, split = 'train'):
305         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
306
307         # Mostly for debug
308         if args.data_size > 0:
309             data_iter = itertools.islice(data_iter, args.data_size)
310
311         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
312
313     def vocabulary_size(self):
314         return len(self.vocab)
315
316     def produce_results(self, n_epoch, model):
317         nb_tokens = 50
318         file_name = f'result_wiki103_{n_epoch:04d}.txt'
319
320         with open(file_name, 'w') as outfile:
321              for primer in [
322                      'the cat is hunting a',
323                      'paris is the capital',
324                      'cars are convenient',
325                      'the difference between men and women is',
326                      'the object was blue all over and green all over it was',
327                      'cherries are red and lemons are',
328                      'cherries are sweet and lemons are',
329                      'two plus three equals',
330                      'deep learning is',
331              ]:
332                  t_primer = self.tokenizer(primer)
333                  t_generated = [ ]
334
335                  for j in range(nb_tokens):
336
337                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
338                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
339                      output = model(input)
340                      logits = output[0, -1]
341                      if args.synthesis_sampling:
342                          dist = torch.distributions.categorical.Categorical(logits = logits)
343                          t_next = dist.sample()
344                      else:
345                          t_next = logits.argmax()
346                      t_generated.append(self.vocab.lookup_token(t_next))
347                      if t_generated[-1] == '<nul>': break
348
349                  s = ' '.join(t_generated)
350
351                  outfile.write(f'<{primer}> {s}\n')
352
353         log_string(f'wrote {file_name}')
354
355 ######################################################################
356
357 class TaskMNIST(Task):
358
359     def __init__(self, batch_size, device = torch.device('cpu')):
360         self.device = device
361         self.batch_size = batch_size
362
363     def batches(self, split = 'train'):
364         assert split in { 'train', 'test' }
365         data_set = torchvision.datasets.MNIST(
366             root = './data', train = (split == 'train'),
367             download = True
368         )
369         data_input = data_set.data.view(-1, 28 * 28).long()
370         if args.data_size >= 0:
371             data_input = data_input[:args.data_size]
372         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
373             yield batch
374
375     def vocabulary_size(self):
376         return 256
377
378     def produce_results(self, n_epoch, model):
379         nb_samples = 64
380         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
381         image_name = f'result_mnist_{n_epoch:04d}.png'
382         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
383                                      image_name, nrow = 16, pad_value = 0.8)
384         log_string(f'wrote {image_name}')
385
386 ######################################################################
387
388 log_string(f'device {device}')
389
390 if args.data == 'wiki103':
391     nb_epochs_default = 10
392     task = TaskWiki103(batch_size = args.batch_size, device = device)
393 elif args.data == 'mnist':
394     nb_epochs_default = 25
395     task = TaskMNIST(batch_size = args.batch_size, device = device)
396 elif args.data == 'picoclvr':
397     nb_epochs_default = 10
398     task = TaskPicoCLVR(batch_size = args.batch_size,
399                         height = args.picoclvr_height,
400                         width = args.picoclvr_width,
401                         nb_colors = args.picoclvr_nb_colors,
402                         device = device)
403 else:
404     raise ValueError(f'Unknown dataset {args.data}.')
405
406 vocabulary_size = task.vocabulary_size()
407
408 log_string(f'vocabulary_size {vocabulary_size}')
409
410 ##############################
411
412 model = mygpt.MyGPT(
413     vocabulary_size = vocabulary_size,
414     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
415     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
416 )
417
418 model.to(device)
419
420 nb_parameters = sum(p.numel() for p in model.parameters())
421 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
422
423 ######################################################################
424
425 if args.optim == 'sgd':
426     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
427 elif args.optim == 'adam':
428     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
429 elif args.optim == 'adamw':
430     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
431 else:
432     raise ValueError(f'Unknown optimizer {args.optim}.')
433
434 ######################################################################
435
436 nb_epochs_finished = 0
437
438 if args.no_checkpoint:
439     log_string(f'not trying to load checkpoint.')
440
441 else:
442     try:
443         checkpoint = torch.load(args.checkpoint_name, map_location = device)
444         nb_epochs_finished = checkpoint['nb_epochs_finished']
445         model.load_state_dict(checkpoint['model_state'])
446         optimizer.load_state_dict(checkpoint['optimizer_state'])
447         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
448
449     except FileNotFoundError:
450         log_string('starting from scratch.')
451
452     except:
453         log_string('error when loading the checkpoint.')
454         exit(1)
455
456 ######################################################################
457
458 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
459
460 token_count = 0
461 for input in task.batches(split = 'train'):
462     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
463 token_probas = token_count / token_count.sum()
464 entropy = -torch.xlogy(token_probas, token_probas).sum()
465 train_set_perplexity = math.exp(entropy)
466 #log_string(f'train set perplexity {train_set_perplexity}')
467
468 for k in range(nb_epochs_finished, nb_epochs):
469
470     model.train()
471
472     nb_train_samples, acc_train_loss = 0, 0.0
473
474     for input in task.batches(split = 'train'):
475         input = input.to(device)
476         output = model(input)
477         loss = F.cross_entropy(output.transpose(1, 2), input)
478         acc_train_loss += loss.item() * input.size(0)
479         nb_train_samples += input.size(0)
480
481         optimizer.zero_grad()
482         loss.backward()
483         optimizer.step()
484
485     with torch.autograd.no_grad():
486
487         model.eval()
488
489         nb_test_samples, acc_test_loss = 0, 0.0
490
491         for input in task.batches(split = 'test'):
492             input = input.to(device)
493             output = model(input)
494             loss = F.cross_entropy(output.transpose(1, 2), input)
495             acc_test_loss += loss.item() * input.size(0)
496             nb_test_samples += input.size(0)
497
498         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
499         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
500
501         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
502
503         task.produce_results(k, model)
504
505     checkpoint = {
506         'nb_epochs_finished': k + 1,
507         'model_state': model.state_dict(),
508         'optimizer_state': optimizer.state_dict()
509     }
510
511     torch.save(checkpoint, args.checkpoint_name)
512
513 ######################################################################