OCDC
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in range(first, input.size(1)):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     # Make a tensor from a list of strings
160     def tensorize(self, descr):
161         token_descr = [ s.strip().split(' ') for s in descr ]
162         l = max([ len(s) for s in token_descr ])
163         #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
164         token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
165         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
166         return torch.tensor(id_descr, device = self.device)
167
168     def __init__(self, batch_size,
169                  height, width, nb_colors = 5,
170                  device = torch.device('cpu')):
171
172         def generate_descr(nb):
173             return picoclvr.generate(
174                 nb,
175                 height = self.height, width = self.width,
176                 nb_colors = nb_colors
177             )
178
179         self.height = height
180         self.width = width
181         self.batch_size = batch_size
182         self.device = device
183         nb = args.data_size if args.data_size > 0 else 250000
184
185         self.train_descr = generate_descr((nb * 4) // 5)
186         self.test_descr = generate_descr((nb * 1) // 5)
187
188         # Build the tokenizer
189         tokens = { '<nul>' }
190         for d in [ self.train_descr, self.test_descr ]:
191             for s in d:
192                 for t in s.strip().split(' '): tokens.add(t)
193         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
194         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
195
196         # Tokenize the train and test sets
197         self.train_input = self.tensorize(self.train_descr)
198         self.test_input = self.tensorize(self.test_descr)
199
200     def batches(self, split = 'train'):
201         assert split in { 'train', 'test' }
202         input = self.train_input if split == 'train' else self.test_input
203         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
204             yield batch
205
206     def vocabulary_size(self):
207         return len(self.token2id)
208
209     def produce_results(self, n_epoch, model):
210         nb_tokens = self.height * self.width + 3
211         result_descr = [ ]
212         nb_per_primer = 8
213
214         for primer_descr in [
215                 'red above green <sep> green top <sep> blue right of red <img>',
216                 'there is red <sep> there is yellow <sep> there is blue <img>',
217                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
218                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
219         ]:
220
221             for k in range(nb_per_primer):
222                 results = autoregression(
223                     model, self.batch_size,
224                     nb_samples = 1, nb_tokens_to_generate = nb_tokens,
225                     primer = self.tensorize([ primer_descr ]),
226                     device = self.device
227                 )
228                 r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
229                 result_descr.append(r)
230
231         img = [
232             picoclvr.descr2img(d, height = self.height, width = self.width)
233             for d in result_descr
234         ]
235
236         img = torch.cat(img, 0)
237         image_name = f'result_picoclvr_{n_epoch:04d}.png'
238         torchvision.utils.save_image(
239             img / 255.,
240             image_name, nrow = nb_per_primer, pad_value = 0.8
241         )
242         log_string(f'wrote {image_name}')
243
244         np = picoclvr.nb_properties(
245             result_descr,
246             height = self.height, width = self.width
247         )
248
249         nb_requested_properties, _, nb_missing_properties = zip(*np)
250
251         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
252
253 ######################################################################
254
255 class TaskWiki103(Task):
256
257     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
258                  device = torch.device('cpu')):
259
260         self.batch_size = batch_size
261         self.len_min = len_min
262         self.len_max = len_max
263         self.min_freq = min_freq
264         self.device = device
265
266         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
267         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
268
269         # Mostly for debug
270         if args.data_size > 0:
271             train_iter = itertools.islice(train_iter, args.data_size)
272
273         def yield_tokens():
274             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
275                 yield self.tokenizer(l)
276
277         self.vocab = torchtext.vocab.build_vocab_from_iterator(
278             yield_tokens(),
279             specials = [ '<unk>', '<nul>' ],
280             min_freq = self.min_freq
281         )
282
283         self.vocab.set_default_index(self.vocab[ '<unk>' ])
284
285     # makes a tensor from a list of list of tokens
286     def tensorize(self, s):
287         a = max(len(x) for x in s)
288         return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
289
290     def yield_batches(self, ds):
291         s = [ ]
292         for l in ds:
293             q = self.tokenizer(l)
294             if len(q) >= self.len_min and len(q) <= self.len_max:
295                 s += [ q ]
296                 if len(s) == self.batch_size:
297                     yield self.tensorize(s)
298                     s = [ ]
299
300         if len(s) > 0:
301             yield self.tensorize(s)
302
303     def batches(self, split = 'train'):
304         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
305
306         # Mostly for debug
307         if args.data_size > 0:
308             data_iter = itertools.islice(data_iter, args.data_size)
309
310         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
311
312     def vocabulary_size(self):
313         return len(self.vocab)
314
315     def produce_results(self, n_epoch, model):
316         nb_tokens = 50
317         file_name = f'result_wiki103_{n_epoch:04d}.txt'
318
319         with open(file_name, 'w') as outfile:
320              for primer in [
321                      'the cat is hunting a',
322                      'paris is the capital',
323                      'cars are convenient',
324                      'the difference between men and women is',
325                      'the object was blue all over and green all over it was',
326                      'cherries are red and lemons are',
327                      'cherries are sweet and lemons are',
328                      'two plus three equals',
329                      'deep learning is',
330              ]:
331                  t_primer = self.tokenizer(primer)
332                  t_generated = [ ]
333
334                  for j in range(nb_tokens):
335
336                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
337                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
338                      output = model(input)
339                      logits = output[0, -1]
340                      if args.synthesis_sampling:
341                          dist = torch.distributions.categorical.Categorical(logits = logits)
342                          t_next = dist.sample()
343                      else:
344                          t_next = logits.argmax()
345                      t_generated.append(self.vocab.lookup_token(t_next))
346                      if t_generated[-1] == '<nul>': break
347
348                  s = ' '.join(t_generated)
349
350                  outfile.write(f'<{primer}> {s}\n')
351
352         log_string(f'wrote {file_name}')
353
354 ######################################################################
355
356 class TaskMNIST(Task):
357
358     def __init__(self, batch_size, device = torch.device('cpu')):
359         self.device = device
360         self.batch_size = batch_size
361
362     def batches(self, split = 'train'):
363         assert split in { 'train', 'test' }
364         data_set = torchvision.datasets.MNIST(
365             root = './data', train = (split == 'train'),
366             download = True
367         )
368         data_input = data_set.data.view(-1, 28 * 28).long()
369         if args.data_size >= 0:
370             data_input = data_input[:args.data_size]
371         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
372             yield batch
373
374     def vocabulary_size(self):
375         return 256
376
377     def produce_results(self, n_epoch, model):
378         nb_samples = 64
379         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
380         image_name = f'result_mnist_{n_epoch:04d}.png'
381         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
382                                      image_name, nrow = 16, pad_value = 0.8)
383         log_string(f'wrote {image_name}')
384
385 ######################################################################
386
387 log_string(f'device {device}')
388
389 if args.data == 'wiki103':
390     nb_epochs_default = 10
391     task = TaskWiki103(batch_size = args.batch_size, device = device)
392 elif args.data == 'mnist':
393     nb_epochs_default = 25
394     task = TaskMNIST(batch_size = args.batch_size, device = device)
395 elif args.data == 'picoclvr':
396     nb_epochs_default = 10
397     task = TaskPicoCLVR(batch_size = args.batch_size,
398                         height = args.picoclvr_height,
399                         width = args.picoclvr_width,
400                         nb_colors = args.picoclvr_nb_colors,
401                         device = device)
402 else:
403     raise ValueError(f'Unknown dataset {args.data}.')
404
405 vocabulary_size = task.vocabulary_size()
406
407 log_string(f'vocabulary_size {vocabulary_size}')
408
409 ##############################
410
411 model = mygpt.MyGPT(
412     vocabulary_size = vocabulary_size,
413     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
414     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
415 )
416
417 model.to(device)
418
419 nb_parameters = sum(p.numel() for p in model.parameters())
420 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
421
422 ######################################################################
423
424 if args.optim == 'sgd':
425     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
426 elif args.optim == 'adam':
427     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
428 elif args.optim == 'adamw':
429     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
430 else:
431     raise ValueError(f'Unknown optimizer {args.optim}.')
432
433 ######################################################################
434
435 nb_epochs_finished = 0
436
437 if args.no_checkpoint:
438     log_string(f'not trying to load checkpoint.')
439
440 else:
441     try:
442         checkpoint = torch.load(args.checkpoint_name, map_location = device)
443         nb_epochs_finished = checkpoint['nb_epochs_finished']
444         model.load_state_dict(checkpoint['model_state'])
445         optimizer.load_state_dict(checkpoint['optimizer_state'])
446         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
447
448     except FileNotFoundError:
449         log_string('starting from scratch.')
450
451     except:
452         log_string('error when loading the checkpoint.')
453         exit(1)
454
455 ######################################################################
456
457 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
458
459 token_count = 0
460 for input in task.batches(split = 'train'):
461     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
462 token_probas = token_count / token_count.sum()
463 entropy = -torch.xlogy(token_probas, token_probas).sum()
464 train_set_perplexity = math.exp(entropy)
465 #log_string(f'train set perplexity {train_set_perplexity}')
466
467 for k in range(nb_epochs_finished, nb_epochs):
468
469     model.train()
470
471     nb_train_samples, acc_train_loss = 0, 0.0
472
473     for input in task.batches(split = 'train'):
474         input = input.to(device)
475         output = model(input)
476         loss = F.cross_entropy(output.transpose(1, 2), input)
477         acc_train_loss += loss.item() * input.size(0)
478         nb_train_samples += input.size(0)
479
480         optimizer.zero_grad()
481         loss.backward()
482         optimizer.step()
483
484     with torch.autograd.no_grad():
485
486         model.eval()
487
488         nb_test_samples, acc_test_loss = 0, 0.0
489
490         for input in task.batches(split = 'test'):
491             input = input.to(device)
492             output = model(input)
493             loss = F.cross_entropy(output.transpose(1, 2), input)
494             acc_test_loss += loss.item() * input.size(0)
495             nb_test_samples += input.size(0)
496
497         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
498         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
499
500         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
501
502         task.produce_results(k, model)
503
504     checkpoint = {
505         'nb_epochs_finished': k + 1,
506         'model_state': model.state_dict(),
507         'optimizer_state': optimizer.state_dict()
508     }
509
510     torch.save(checkpoint, args.checkpoint_name)
511
512 ######################################################################