Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     action='store_true', default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     action='store_true', default = True)
71
72 parser.add_argument('--checkpoint_name',
73                     type = str, default = 'checkpoint.pth')
74
75 parser.add_argument('--picoclvr_many_colors',
76                     action='store_true', default = False)
77
78 ######################################################################
79
80 args = parser.parse_args()
81
82 log_file = open(args.log_filename, 'w')
83
84 if args.seed >= 0:
85     torch.manual_seed(args.seed)
86
87 ######################################################################
88
89 def log_string(s):
90     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
91
92     if log_file is not None:
93         log_file.write(t + s + '\n')
94         log_file.flush()
95
96     print(t + s)
97     sys.stdout.flush()
98
99 for n in vars(args):
100     log_string(f'args.{n} {getattr(args, n)}')
101
102 ######################################################################
103
104 class Task:
105     def batches(self, split = 'train'):
106         pass
107
108     def vocabulary_size(self):
109         pass
110
111     def produce_results(self, n_epoch, model, nb_tokens = 50):
112         pass
113
114 ######################################################################
115
116 import picoclvr
117
118 class TaskPicoCLVR(Task):
119
120     def __init__(self, batch_size,
121                  height = 6, width = 8, many_colors = False,
122                  device = torch.device('cpu')):
123
124         self.batch_size = batch_size
125         self.device = device
126         nb = args.data_size if args.data_size > 0 else 250000
127
128         descr = picoclvr.generate(
129             nb,
130             height = height, width = width,
131             many_colors = many_colors
132         )
133
134         descr = [ s.strip().split(' ') for s in descr ]
135         l = max([ len(s) for s in descr ])
136         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
137
138         tokens = set()
139         for s in descr:
140             for t in s: tokens.add(t)
141         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
142         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
143
144         t = [ [ self.token2id[u] for u in s ] for s in descr ]
145         data_input = torch.tensor(t, device = self.device)
146
147         self.test_input = data_input[:nb // 5]
148         self.train_input = data_input[nb // 5:]
149
150     def batches(self, split = 'train'):
151         assert split in { 'train', 'test' }
152         if split == 'train':
153             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
154                 yield batch
155         else:
156             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
157                 yield batch
158
159     def vocabulary_size(self):
160         return len(self.token2id)
161
162     def produce_results(self, n_epoch, model, nb_tokens = 50):
163         img = [ ]
164         nb_per_primer = 8
165
166         for primer in [
167                 'red above green <sep> green top <sep> blue right of red <img>',
168                 'there is red <sep> there is yellow <sep> there is blue <img>',
169                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
170                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
171         ]:
172
173             for k in range(nb_per_primer):
174                 t_primer = primer.strip().split(' ')
175                 t_generated = [ ]
176
177                 for j in range(nb_tokens):
178                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
179                     input = torch.tensor(t, device = self.device)
180                     output = model(input)
181                     logits = output[0, -1]
182                     if args.synthesis_sampling:
183                         dist = torch.distributions.categorical.Categorical(logits = logits)
184                         t = dist.sample()
185                     else:
186                         t = logits.argmax()
187                     t_generated.append(self.id2token[t.item()])
188
189                 descr = [ ' '.join(t_primer + t_generated) ]
190                 img += [ picoclvr.descr2img(descr) ]
191
192         img = torch.cat(img, 0)
193         file_name = f'result_picoclvr_{n_epoch:04d}.png'
194         torchvision.utils.save_image(img / 255.,
195                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
196         log_string(f'wrote {file_name}')
197
198 ######################################################################
199
200 class TaskWiki103(Task):
201
202     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
203                  device = torch.device('cpu')):
204
205         self.batch_size = batch_size
206         self.len_min = len_min
207         self.len_max = len_max
208         self.min_freq = min_freq
209         self.device = device
210
211         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
212         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
213
214         # Mostly for debug
215         if args.data_size > 0:
216             train_iter = itertools.islice(train_iter, args.data_size)
217
218         def yield_tokens():
219             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
220                 yield self.tokenizer(l)
221
222         self.vocab = torchtext.vocab.build_vocab_from_iterator(
223             yield_tokens(),
224             specials = [ '<unk>', '<non>' ],
225             min_freq = self.min_freq
226         )
227
228         self.vocab.set_default_index(self.vocab[ '<unk>' ])
229
230     def tensorize(self, s):
231         a = max(len(x) for x in s)
232         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
233
234     def yield_batches(self, ds):
235         s = [ ]
236         for l in ds:
237             q = self.tokenizer(l)
238             if len(q) >= self.len_min and len(q) <= self.len_max:
239                 s += [ q ]
240                 if len(s) == self.batch_size:
241                     yield self.tensorize(s)
242                     s = [ ]
243
244         if len(s) > 0:
245             yield self.tensorize(s)
246
247     def batches(self, split = 'train'):
248         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
249
250         # Mostly for debug
251         if args.data_size > 0:
252             data_iter = itertools.islice(data_iter, args.data_size)
253
254         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
255
256     def vocabulary_size(self):
257         return len(self.vocab)
258
259     def produce_results(self, n_epoch, model, nb_tokens = 50):
260         file_name = f'result_wiki103_{n_epoch:04d}.txt'
261
262         with open(file_name, 'w') as outfile:
263              for primer in [
264                      'the cat is hunting a',
265                      'paris is the capital',
266                      'cars are convenient',
267                      'the difference between men and women is',
268                      'the object was blue all over and green all over it was',
269                      'cherries are red and lemons are',
270                      'cherries are sweet and lemons are',
271                      'two plus three equals',
272                      'deep learning is',
273              ]:
274                  t_primer = self.tokenizer(primer)
275                  t_generated = [ ]
276
277                  for j in range(nb_tokens):
278
279                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
280                      output = model(input)
281                      logits = output[0, -1]
282                      if args.synthesis_sampling:
283                          dist = torch.distributions.categorical.Categorical(logits = logits)
284                          t = dist.sample()
285                      else:
286                          t = logits.argmax()
287                      t_generated.append(self.vocab.lookup_token(t))
288                      if t_generated[-1] == '<non>': break
289
290                  s = ' '.join(t_generated)
291
292                  outfile.write(f'<{primer}> {s}\n')
293
294         log_string(f'wrote {file_name}')
295
296 ######################################################################
297
298 class TaskMNIST(Task):
299
300     def __init__(self, batch_size, device = torch.device('cpu')):
301         self.device = device
302         self.batch_size = batch_size
303
304     def batches(self, split = 'train'):
305         assert split in { 'train', 'test' }
306         data_set = torchvision.datasets.MNIST(
307             root = './data', train = (split == 'train'),
308             download = True
309         )
310         data_input = data_set.data.view(-1, 28 * 28).long()
311         if args.data_size >= 0:
312             data_input = data_input[:args.data_size]
313         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
314             yield batch
315
316     def vocabulary_size(self):
317         return 256
318
319     def produce_results(self, n_epoch, model, nb_samples = 64):
320         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
321         for input in results.split(self.batch_size):
322             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
323                 output = model(input)
324                 logits = output[:, s]
325                 if args.synthesis_sampling:
326                     dist = torch.distributions.categorical.Categorical(logits = logits)
327                     t = dist.sample()
328                 else:
329                     t = logits.argmax(1)
330                 input[:, s + 1] = t
331
332         image_name = f'result_mnist_{n_epoch:04d}.png'
333         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
334                                      image_name, nrow = 16, pad_value = 0.8)
335         log_string(f'wrote {image_name}')
336
337 ######################################################################
338
339 def check_causality(model):
340     #m = model[1:]
341     input = torch.rand(1, 5, dim_model).requires_grad_()
342     output = m(input)
343     a = torch.zeros(output.size(1), input.size(1))
344     for k in range(output.size(1)):
345         for d in range(output.size(2)):
346             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
347             a[k] += g.squeeze(0).pow(2).sum(1)
348     print(a)
349
350 ######################################################################
351
352 log_string(f'device {device}')
353
354 if args.data == 'wiki103':
355     task = TaskWiki103(batch_size = args.batch_size, device = device)
356 elif args.data == 'mnist':
357     task = TaskMNIST(batch_size = args.batch_size, device = device)
358 elif args.data == 'picoclvr':
359     task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
360 else:
361     raise ValueError(f'Unknown dataset {args.data}.')
362
363 vocabulary_size = task.vocabulary_size()
364
365 log_string(f'vocabulary_size {vocabulary_size}')
366
367 ##############################
368
369 model = mygpt.MyGPT(
370     vocabulary_size = vocabulary_size,
371     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
372     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
373 )
374
375 model.to(device)
376
377 nb_parameters = sum(p.numel() for p in model.parameters())
378 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
379
380 ######################################################################
381
382 if args.optim == 'sgd':
383     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
384 elif args.optim == 'adam':
385     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
386 elif args.optim == 'adamw':
387     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
388 else:
389     raise ValueError(f'Unknown optimizer {args.optim}.')
390
391 ######################################################################
392
393 nb_epochs_finished = 0
394
395 try:
396     checkpoint = torch.load(args.checkpoint_name, map_location = device)
397     nb_epochs_finished = checkpoint['nb_epochs_finished']
398     model.load_state_dict(checkpoint['model_state'])
399     optimizer.load_state_dict(checkpoint['optimizer_state'])
400     print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
401
402 except FileNotFoundError:
403     print('Starting from scratch.')
404
405 except:
406     print('Error when loading the checkpoint.')
407     exit(1)
408
409 ######################################################################
410
411 for k in range(nb_epochs_finished, args.nb_epochs):
412
413     model.train()
414
415     nb_train_samples, acc_train_loss = 0, 0.0
416
417     for input in task.batches(split = 'train'):
418         input = input.to(device)
419         output = model(input)
420         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
421         acc_train_loss += loss.item() * input.size(0)
422         nb_train_samples += input.size(0)
423
424         optimizer.zero_grad()
425         loss.backward()
426         optimizer.step()
427
428     with torch.autograd.no_grad():
429
430         model.eval()
431
432         nb_test_samples, acc_test_loss = 0, 0.0
433
434         for input in task.batches(split = 'test'):
435             input = input.to(device)
436             output = model(input)
437             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
438             acc_test_loss += loss.item() * input.size(0)
439             nb_test_samples += input.size(0)
440
441         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
442         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
443
444         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
445
446         task.produce_results(k, model)
447
448     checkpoint = {
449         'nb_epochs_finished': k + 1,
450         'model_state': model.state_dict(),
451         'optimizer_state': optimizer.state_dict()
452     }
453
454     torch.save(checkpoint, args.checkpoint_name)
455
456 ######################################################################