Initialize properly w_o.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 ##############################
76 # picoclvr options
77
78 parser.add_argument('--picoclvr_many_colors',
79                     action='store_true', default = False)
80
81 parser.add_argument('--picoclvr_height',
82                     type = int, default = 12)
83
84 parser.add_argument('--picoclvr_width',
85                     type = int, default = 16)
86
87 ######################################################################
88
89 args = parser.parse_args()
90
91 log_file = open(args.log_filename, 'w')
92
93 if args.seed >= 0:
94     torch.manual_seed(args.seed)
95
96 ######################################################################
97
98 def log_string(s):
99     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
100
101     if log_file is not None:
102         log_file.write(t + s + '\n')
103         log_file.flush()
104
105     print(t + s)
106     sys.stdout.flush()
107
108 for n in vars(args):
109     log_string(f'args.{n} {getattr(args, n)}')
110
111 ######################################################################
112
113 class Task:
114     def batches(self, split = 'train'):
115         pass
116
117     def vocabulary_size(self):
118         pass
119
120     def produce_results(self, n_epoch, model, nb_tokens = 50):
121         pass
122
123 ######################################################################
124
125 import picoclvr
126
127 class TaskPicoCLVR(Task):
128
129     def __init__(self, batch_size,
130                  height, width, many_colors = False,
131                  device = torch.device('cpu')):
132
133         self.height = height
134         self.width = width
135         self.batch_size = batch_size
136         self.device = device
137         nb = args.data_size if args.data_size > 0 else 250000
138
139         descr = picoclvr.generate(
140             nb,
141             height = self.height, width = self.width,
142             many_colors = many_colors
143         )
144
145         # self.test_descr = descr[:nb // 5]
146         # self.train_descr = descr[nb // 5:]
147
148         descr = [ s.strip().split(' ') for s in descr ]
149         l = max([ len(s) for s in descr ])
150         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
151
152         tokens = set()
153         for s in descr:
154             for t in s: tokens.add(t)
155         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
156         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
157
158         t = [ [ self.token2id[u] for u in s ] for s in descr ]
159         data_input = torch.tensor(t, device = self.device)
160
161         self.test_input = data_input[:nb // 5]
162         self.train_input = data_input[nb // 5:]
163
164     def batches(self, split = 'train'):
165         assert split in { 'train', 'test' }
166         if split == 'train':
167             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
168                 yield batch
169         else:
170             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
171                 yield batch
172
173     def vocabulary_size(self):
174         return len(self.token2id)
175
176     def generate(self, primer, model, nb_tokens):
177         t_primer = primer.strip().split(' ')
178         t_generated = [ ]
179
180         for j in range(nb_tokens):
181             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
182             input = torch.tensor(t, device = self.device)
183             output = model(input)
184             logits = output[0, -1]
185             if args.synthesis_sampling:
186                 dist = torch.distributions.categorical.Categorical(logits = logits)
187                 t = dist.sample()
188             else:
189                 t = logits.argmax()
190             t_generated.append(self.id2token[t.item()])
191
192         return ' '.join(t_primer + t_generated)
193
194     def produce_results(self, n_epoch, model, nb_tokens = None):
195         if nb_tokens is None:
196             nb_tokens = self.height * self.width + 3
197         descr = [ ]
198         nb_per_primer = 8
199
200         for primer in [
201                 'red above green <sep> green top <sep> blue right of red <img>',
202                 'there is red <sep> there is yellow <sep> there is blue <img>',
203                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
204                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
205         ]:
206
207             for k in range(nb_per_primer):
208                 descr.append(self.generate(primer, model, nb_tokens))
209
210         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
211         img = torch.cat(img, 0)
212         file_name = f'result_picoclvr_{n_epoch:04d}.png'
213         torchvision.utils.save_image(
214             img / 255.,
215             file_name, nrow = nb_per_primer, pad_value = 0.8
216         )
217         log_string(f'wrote {file_name}')
218
219         nb_missing = sum( [
220             x[2] for x in picoclvr.nb_missing_properties(
221                 descr,
222                 height = self.height, width = self.width
223             )
224         ] )
225
226         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
227
228 ######################################################################
229
230 class TaskWiki103(Task):
231
232     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
233                  device = torch.device('cpu')):
234
235         self.batch_size = batch_size
236         self.len_min = len_min
237         self.len_max = len_max
238         self.min_freq = min_freq
239         self.device = device
240
241         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
242         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
243
244         # Mostly for debug
245         if args.data_size > 0:
246             train_iter = itertools.islice(train_iter, args.data_size)
247
248         def yield_tokens():
249             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
250                 yield self.tokenizer(l)
251
252         self.vocab = torchtext.vocab.build_vocab_from_iterator(
253             yield_tokens(),
254             specials = [ '<unk>', '<non>' ],
255             min_freq = self.min_freq
256         )
257
258         self.vocab.set_default_index(self.vocab[ '<unk>' ])
259
260     def tensorize(self, s):
261         a = max(len(x) for x in s)
262         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
263
264     def yield_batches(self, ds):
265         s = [ ]
266         for l in ds:
267             q = self.tokenizer(l)
268             if len(q) >= self.len_min and len(q) <= self.len_max:
269                 s += [ q ]
270                 if len(s) == self.batch_size:
271                     yield self.tensorize(s)
272                     s = [ ]
273
274         if len(s) > 0:
275             yield self.tensorize(s)
276
277     def batches(self, split = 'train'):
278         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
279
280         # Mostly for debug
281         if args.data_size > 0:
282             data_iter = itertools.islice(data_iter, args.data_size)
283
284         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
285
286     def vocabulary_size(self):
287         return len(self.vocab)
288
289     def produce_results(self, n_epoch, model, nb_tokens = 50):
290         file_name = f'result_wiki103_{n_epoch:04d}.txt'
291
292         with open(file_name, 'w') as outfile:
293              for primer in [
294                      'the cat is hunting a',
295                      'paris is the capital',
296                      'cars are convenient',
297                      'the difference between men and women is',
298                      'the object was blue all over and green all over it was',
299                      'cherries are red and lemons are',
300                      'cherries are sweet and lemons are',
301                      'two plus three equals',
302                      'deep learning is',
303              ]:
304                  t_primer = self.tokenizer(primer)
305                  t_generated = [ ]
306
307                  for j in range(nb_tokens):
308
309                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
310                      output = model(input)
311                      logits = output[0, -1]
312                      if args.synthesis_sampling:
313                          dist = torch.distributions.categorical.Categorical(logits = logits)
314                          t = dist.sample()
315                      else:
316                          t = logits.argmax()
317                      t_generated.append(self.vocab.lookup_token(t))
318                      if t_generated[-1] == '<non>': break
319
320                  s = ' '.join(t_generated)
321
322                  outfile.write(f'<{primer}> {s}\n')
323
324         log_string(f'wrote {file_name}')
325
326 ######################################################################
327
328 class TaskMNIST(Task):
329
330     def __init__(self, batch_size, device = torch.device('cpu')):
331         self.device = device
332         self.batch_size = batch_size
333
334     def batches(self, split = 'train'):
335         assert split in { 'train', 'test' }
336         data_set = torchvision.datasets.MNIST(
337             root = './data', train = (split == 'train'),
338             download = True
339         )
340         data_input = data_set.data.view(-1, 28 * 28).long()
341         if args.data_size >= 0:
342             data_input = data_input[:args.data_size]
343         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
344             yield batch
345
346     def vocabulary_size(self):
347         return 256
348
349     def produce_results(self, n_epoch, model, nb_samples = 64):
350         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
351         for input in results.split(self.batch_size):
352             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
353                 output = model(input)
354                 logits = output[:, s]
355                 if args.synthesis_sampling:
356                     dist = torch.distributions.categorical.Categorical(logits = logits)
357                     t = dist.sample()
358                 else:
359                     t = logits.argmax(1)
360                 input[:, s + 1] = t
361
362         image_name = f'result_mnist_{n_epoch:04d}.png'
363         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
364                                      image_name, nrow = 16, pad_value = 0.8)
365         log_string(f'wrote {image_name}')
366
367 ######################################################################
368
369 def check_causality(model):
370     #m = model[1:]
371     input = torch.rand(1, 5, dim_model).requires_grad_()
372     output = m(input)
373     a = torch.zeros(output.size(1), input.size(1))
374     for k in range(output.size(1)):
375         for d in range(output.size(2)):
376             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
377             a[k] += g.squeeze(0).pow(2).sum(1)
378     print(a)
379
380 ######################################################################
381
382 log_string(f'device {device}')
383
384 if args.data == 'wiki103':
385     task = TaskWiki103(batch_size = args.batch_size, device = device)
386 elif args.data == 'mnist':
387     task = TaskMNIST(batch_size = args.batch_size, device = device)
388 elif args.data == 'picoclvr':
389     task = TaskPicoCLVR(batch_size = args.batch_size,
390                         height = args.picoclvr_height,
391                         width = args.picoclvr_width,
392                         many_colors = args.picoclvr_many_colors,
393                         device = device)
394 else:
395     raise ValueError(f'Unknown dataset {args.data}.')
396
397 vocabulary_size = task.vocabulary_size()
398
399 log_string(f'vocabulary_size {vocabulary_size}')
400
401 ##############################
402
403 model = mygpt.MyGPT(
404     vocabulary_size = vocabulary_size,
405     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
406     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
407 )
408
409 model.to(device)
410
411 nb_parameters = sum(p.numel() for p in model.parameters())
412 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
413
414 ######################################################################
415
416 if args.optim == 'sgd':
417     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
418 elif args.optim == 'adam':
419     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
420 elif args.optim == 'adamw':
421     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
422 else:
423     raise ValueError(f'Unknown optimizer {args.optim}.')
424
425 ######################################################################
426
427 nb_epochs_finished = 0
428
429 try:
430     checkpoint = torch.load(args.checkpoint_name, map_location = device)
431     nb_epochs_finished = checkpoint['nb_epochs_finished']
432     model.load_state_dict(checkpoint['model_state'])
433     optimizer.load_state_dict(checkpoint['optimizer_state'])
434     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
435
436 except FileNotFoundError:
437     print('Starting from scratch.')
438
439 except:
440     print('Error when loading the checkpoint.')
441     exit(1)
442
443 ######################################################################
444
445 for k in range(nb_epochs_finished, args.nb_epochs):
446
447     model.train()
448
449     nb_train_samples, acc_train_loss = 0, 0.0
450
451     for input in task.batches(split = 'train'):
452         input = input.to(device)
453         output = model(input)
454         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
455         acc_train_loss += loss.item() * input.size(0)
456         nb_train_samples += input.size(0)
457
458         optimizer.zero_grad()
459         loss.backward()
460         optimizer.step()
461
462     with torch.autograd.no_grad():
463
464         model.eval()
465
466         nb_test_samples, acc_test_loss = 0, 0.0
467
468         for input in task.batches(split = 'test'):
469             input = input.to(device)
470             output = model(input)
471             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
472             acc_test_loss += loss.item() * input.size(0)
473             nb_test_samples += input.size(0)
474
475         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
476         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
477
478         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
479
480         task.produce_results(k, model)
481
482     checkpoint = {
483         'nb_epochs_finished': k + 1,
484         'model_state': model.state_dict(),
485         'optimizer_state': optimizer.state_dict()
486     }
487
488     torch.save(checkpoint, args.checkpoint_name)
489
490 ######################################################################