OCDC
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21 parser = argparse.ArgumentParser(description = 'My own GPT.')
22
23 parser.add_argument('--log_filename',
24                     type = str, default = 'train.log')
25
26 parser.add_argument('--seed',
27                     type = int, default = 0)
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = -1)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 25)
34
35 parser.add_argument('--data',
36                     type = str, default = 'wiki103')
37
38 parser.add_argument('--data_size',
39                     type = int, default = -1)
40
41 parser.add_argument('--optim',
42                     type = str, default = 'adam')
43
44 parser.add_argument('--learning_rate',
45                     type = float, default = 1e-4)
46
47 parser.add_argument('--dim_model',
48                     type = int, default = 512)
49
50 parser.add_argument('--dim_keys',
51                     type = int, default = 64)
52
53 parser.add_argument('--dim_hidden',
54                     type = int, default = 2048)
55
56 parser.add_argument('--nb_heads',
57                     type = int, default = 8)
58
59 parser.add_argument('--nb_blocks',
60                     type = int, default = 12)
61
62 parser.add_argument('--dropout',
63                     type = float, default = 0.1)
64
65 parser.add_argument('--synthesis_sampling',
66                     action='store_true', default = True)
67
68 parser.add_argument('--no_checkpoint',
69                     action='store_true', default = False)
70
71 parser.add_argument('--checkpoint_name',
72                     type = str, default = 'checkpoint.pth')
73
74 ##############################
75 # picoclvr options
76
77 parser.add_argument('--picoclvr_nb_colors',
78                     type = int, default = 5)
79
80 parser.add_argument('--picoclvr_height',
81                     type = int, default = 12)
82
83 parser.add_argument('--picoclvr_width',
84                     type = int, default = 16)
85
86 ######################################################################
87
88 args = parser.parse_args()
89
90 log_file = open(args.log_filename, 'w')
91
92 if args.seed >= 0:
93     torch.manual_seed(args.seed)
94
95 ######################################################################
96
97 def log_string(s):
98     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
99
100     if log_file is not None:
101         log_file.write(t + s + '\n')
102         log_file.flush()
103
104     print(t + s)
105     sys.stdout.flush()
106
107 for n in vars(args):
108     log_string(f'args.{n} {getattr(args, n)}')
109
110 ######################################################################
111
112 def autoregression(
113         model, batch_size,
114         nb_samples, nb_tokens_to_generate, primer = None,
115         device = torch.device('cpu')
116 ):
117     results = torch.zeros(
118         nb_samples, nb_tokens_to_generate,
119         dtype = torch.int64, device = device
120     )
121
122     if primer is None:
123         first = 0
124     else:
125         first = primer.size(1)
126         results = torch.cat((primer, results), 1)
127
128     for input in results.split(batch_size):
129         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
130             output = model(input)
131             logits = output[:, s]
132             if args.synthesis_sampling:
133                 dist = torch.distributions.categorical.Categorical(logits = logits)
134                 t_next = dist.sample()
135             else:
136                 t_next = logits.argmax(1)
137             input[:, s] = t_next
138
139     return results
140
141 ######################################################################
142
143 class Task:
144     def batches(self, split = 'train'):
145         pass
146
147     def vocabulary_size(self):
148         pass
149
150     def produce_results(self, n_epoch, model):
151         pass
152
153 ######################################################################
154
155 import picoclvr
156
157 class TaskPicoCLVR(Task):
158
159     def descr2tensor(self, descr):
160         t = [ [ self.token2id[u] for u in s ] for s in descr ]
161         return torch.tensor(t, device = self.device)
162
163     def __init__(self, batch_size,
164                  height, width, nb_colors = 5,
165                  device = torch.device('cpu')):
166
167         def generate_descr(nb):
168             descr = picoclvr.generate(
169                 nb,
170                 height = self.height, width = self.width,
171                 nb_colors = nb_colors
172             )
173
174             descr = [ s.strip().split(' ') for s in descr ]
175             l = max([ len(s) for s in descr ])
176             #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
177             descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
178
179             return descr
180
181         self.height = height
182         self.width = width
183         self.batch_size = batch_size
184         self.device = device
185         nb = args.data_size if args.data_size > 0 else 250000
186
187         self.train_descr = generate_descr((nb * 4) // 5)
188         self.test_descr = generate_descr((nb * 1) // 5)
189
190         # Build the tokenizer
191         tokens = set()
192         for d in [ self.train_descr, self.test_descr ]:
193             for s in d:
194                 for t in s: tokens.add(t)
195         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
196         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
197
198         # Tokenize the train and test sets
199         self.train_input = descr2tensor(self.train_descr)
200         self.test_input = descr2tensor(self.test_descr)
201
202     def batches(self, split = 'train'):
203         assert split in { 'train', 'test' }
204         input = self.train_input if split == 'train' else self.test_input
205         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
206             yield batch
207
208     def vocabulary_size(self):
209         return len(self.token2id)
210
211     def generate(self, primer_descr, model, nb_tokens):
212         results = autoregression(
213             model, self.batch_size,
214             nb_samples = 1, nb_tokens = nb_tokens, primer = descr2tensor(primer_descr),
215             device = self.device
216         )
217         return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
218
219     def produce_results(self, n_epoch, model):
220         nb_tokens = self.height * self.width + 3
221         result_descr = [ ]
222         nb_per_primer = 8
223
224         for primer_descr in [
225                 'red above green <sep> green top <sep> blue right of red <img>',
226                 'there is red <sep> there is yellow <sep> there is blue <img>',
227                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
228                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
229         ]:
230
231             for k in range(nb_per_primer):
232                 result_descr.append(self.generate(primer_descr, model, nb_tokens))
233
234         img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
235                 for d in result_descr ]
236         img = torch.cat(img, 0)
237         image_name = f'result_picoclvr_{n_epoch:04d}.png'
238         torchvision.utils.save_image(
239             img / 255.,
240             image_name, nrow = nb_per_primer, pad_value = 0.8
241         )
242         log_string(f'wrote {image_name}')
243
244         np = picoclvr.nb_properties(
245             result_descr,
246             height = self.height, width = self.width
247         )
248
249         nb_requested_properties, _, nb_missing_properties = zip(*np)
250
251         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
252
253 ######################################################################
254
255 class TaskWiki103(Task):
256
257     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
258                  device = torch.device('cpu')):
259
260         self.batch_size = batch_size
261         self.len_min = len_min
262         self.len_max = len_max
263         self.min_freq = min_freq
264         self.device = device
265
266         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
267         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
268
269         # Mostly for debug
270         if args.data_size > 0:
271             train_iter = itertools.islice(train_iter, args.data_size)
272
273         def yield_tokens():
274             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
275                 yield self.tokenizer(l)
276
277         self.vocab = torchtext.vocab.build_vocab_from_iterator(
278             yield_tokens(),
279             specials = [ '<unk>', '<non>' ],
280             min_freq = self.min_freq
281         )
282
283         self.vocab.set_default_index(self.vocab[ '<unk>' ])
284
285     def tensorize(self, s):
286         a = max(len(x) for x in s)
287         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
288
289     def yield_batches(self, ds):
290         s = [ ]
291         for l in ds:
292             q = self.tokenizer(l)
293             if len(q) >= self.len_min and len(q) <= self.len_max:
294                 s += [ q ]
295                 if len(s) == self.batch_size:
296                     yield self.tensorize(s)
297                     s = [ ]
298
299         if len(s) > 0:
300             yield self.tensorize(s)
301
302     def batches(self, split = 'train'):
303         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
304
305         # Mostly for debug
306         if args.data_size > 0:
307             data_iter = itertools.islice(data_iter, args.data_size)
308
309         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
310
311     def vocabulary_size(self):
312         return len(self.vocab)
313
314     def produce_results(self, n_epoch, model):
315         nb_tokens = 50
316         file_name = f'result_wiki103_{n_epoch:04d}.txt'
317
318         with open(file_name, 'w') as outfile:
319              for primer in [
320                      'the cat is hunting a',
321                      'paris is the capital',
322                      'cars are convenient',
323                      'the difference between men and women is',
324                      'the object was blue all over and green all over it was',
325                      'cherries are red and lemons are',
326                      'cherries are sweet and lemons are',
327                      'two plus three equals',
328                      'deep learning is',
329              ]:
330                  t_primer = self.tokenizer(primer)
331                  t_generated = [ ]
332
333                  for j in range(nb_tokens):
334
335                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
336                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
337                      output = model(input)
338                      logits = output[0, -1]
339                      if args.synthesis_sampling:
340                          dist = torch.distributions.categorical.Categorical(logits = logits)
341                          t_next = dist.sample()
342                      else:
343                          t_next = logits.argmax()
344                      t_generated.append(self.vocab.lookup_token(t_next))
345                      if t_generated[-1] == '<non>': break
346
347                  s = ' '.join(t_generated)
348
349                  outfile.write(f'<{primer}> {s}\n')
350
351         log_string(f'wrote {file_name}')
352
353 ######################################################################
354
355 class TaskMNIST(Task):
356
357     def __init__(self, batch_size, device = torch.device('cpu')):
358         self.device = device
359         self.batch_size = batch_size
360
361     def batches(self, split = 'train'):
362         assert split in { 'train', 'test' }
363         data_set = torchvision.datasets.MNIST(
364             root = './data', train = (split == 'train'),
365             download = True
366         )
367         data_input = data_set.data.view(-1, 28 * 28).long()
368         if args.data_size >= 0:
369             data_input = data_input[:args.data_size]
370         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
371             yield batch
372
373     def vocabulary_size(self):
374         return 256
375
376     def produce_results(self, n_epoch, model):
377         nb_samples = 64
378         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
379         image_name = f'result_mnist_{n_epoch:04d}.png'
380         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
381                                      image_name, nrow = 16, pad_value = 0.8)
382         log_string(f'wrote {image_name}')
383
384 ######################################################################
385
386 log_string(f'device {device}')
387
388 if args.data == 'wiki103':
389     nb_epochs_default = 10
390     task = TaskWiki103(batch_size = args.batch_size, device = device)
391 elif args.data == 'mnist':
392     nb_epochs_default = 25
393     task = TaskMNIST(batch_size = args.batch_size, device = device)
394 elif args.data == 'picoclvr':
395     nb_epochs_default = 10
396     task = TaskPicoCLVR(batch_size = args.batch_size,
397                         height = args.picoclvr_height,
398                         width = args.picoclvr_width,
399                         nb_colors = args.picoclvr_nb_colors,
400                         device = device)
401 else:
402     raise ValueError(f'Unknown dataset {args.data}.')
403
404 vocabulary_size = task.vocabulary_size()
405
406 log_string(f'vocabulary_size {vocabulary_size}')
407
408 ##############################
409
410 model = mygpt.MyGPT(
411     vocabulary_size = vocabulary_size,
412     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
413     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
414 )
415
416 model.to(device)
417
418 nb_parameters = sum(p.numel() for p in model.parameters())
419 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
420
421 ######################################################################
422
423 if args.optim == 'sgd':
424     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
425 elif args.optim == 'adam':
426     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
427 elif args.optim == 'adamw':
428     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
429 else:
430     raise ValueError(f'Unknown optimizer {args.optim}.')
431
432 ######################################################################
433
434 nb_epochs_finished = 0
435
436 if args.no_checkpoint:
437     log_string(f'not trying to load checkpoint.')
438
439 else:
440     try:
441         checkpoint = torch.load(args.checkpoint_name, map_location = device)
442         nb_epochs_finished = checkpoint['nb_epochs_finished']
443         model.load_state_dict(checkpoint['model_state'])
444         optimizer.load_state_dict(checkpoint['optimizer_state'])
445         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
446
447     except FileNotFoundError:
448         log_string('starting from scratch.')
449
450     except:
451         log_string('error when loading the checkpoint.')
452         exit(1)
453
454 ######################################################################
455
456 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
457
458 token_count = 0
459 for input in task.batches(split = 'train'):
460     token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
461 token_probas = token_count / token_count.sum()
462 entropy = -torch.xlogy(token_probas, token_probas).sum()
463 train_set_perplexity = math.exp(entropy)
464 #log_string(f'train set perplexity {train_set_perplexity}')
465
466 for k in range(nb_epochs_finished, nb_epochs):
467
468     model.train()
469
470     nb_train_samples, acc_train_loss = 0, 0.0
471
472     for input in task.batches(split = 'train'):
473         input = input.to(device)
474         output = model(input)
475         loss = F.cross_entropy(output.transpose(1, 2), input)
476         acc_train_loss += loss.item() * input.size(0)
477         nb_train_samples += input.size(0)
478
479         optimizer.zero_grad()
480         loss.backward()
481         optimizer.step()
482
483     with torch.autograd.no_grad():
484
485         model.eval()
486
487         nb_test_samples, acc_test_loss = 0, 0.0
488
489         for input in task.batches(split = 'test'):
490             input = input.to(device)
491             output = model(input)
492             loss = F.cross_entropy(output.transpose(1, 2), input)
493             acc_test_loss += loss.item() * input.size(0)
494             nb_test_samples += input.size(0)
495
496         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
497         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
498
499         log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
500
501         task.produce_results(k, model)
502
503     checkpoint = {
504         'nb_epochs_finished': k + 1,
505         'model_state': model.state_dict(),
506         'optimizer_state': optimizer.state_dict()
507     }
508
509     torch.save(checkpoint, args.checkpoint_name)
510
511 ######################################################################