Finalized PicoCLVR with "many colors".
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     type = bool, default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     type = bool, default = True)
71
72 ######################################################################
73
74 args = parser.parse_args()
75
76 log_file = open(args.log_filename, 'w')
77
78 if args.seed >= 0:
79     torch.manual_seed(args.seed)
80
81 ######################################################################
82
83 def log_string(s):
84     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
85
86     if log_file is not None:
87         log_file.write(t + s + '\n')
88         log_file.flush()
89
90     print(t + s)
91     sys.stdout.flush()
92
93 for n in vars(args):
94     log_string(f'args.{n} {getattr(args, n)}')
95
96 ######################################################################
97
98 class Task:
99     def batches(self, split = 'train'):
100         pass
101
102     def vocabulary_size(self):
103         pass
104
105     def produce_results(self, n_epoch, model, nb_tokens = 50):
106         pass
107
108 ######################################################################
109
110 import picoclvr
111
112 class TaskPicoCLVR(Task):
113
114     def __init__(self, batch_size,
115                  height = 6, width = 8, many_colors = False,
116                  device = torch.device('cpu')):
117
118         self.batch_size = batch_size
119         self.device = device
120         nb = args.data_size if args.data_size > 0 else 250000
121
122         descr = picoclvr.generate(
123             nb,
124             height = height, width = width,
125             many_colors = many_colors
126         )
127
128         descr = [ s.strip().split(' ') for s in descr ]
129         l = max([ len(s) for s in descr ])
130         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
131
132         tokens = set()
133         for s in descr:
134             for t in s: tokens.add(t)
135         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
136         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
137
138         t = [ [ self.token2id[u] for u in s ] for s in descr ]
139         data_input = torch.tensor(t, device = self.device)
140
141         self.test_input = data_input[:nb // 5]
142         self.train_input = data_input[nb // 5:]
143
144     def batches(self, split = 'train'):
145         assert split in { 'train', 'test' }
146         if split == 'train':
147             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
148                 yield batch
149         else:
150             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
151                 yield batch
152
153     def vocabulary_size(self):
154         return len(self.token2id)
155
156     def produce_results(self, n_epoch, model, nb_tokens = 50):
157         img = [ ]
158         nb_per_primer = 8
159
160         for primer in [
161                 'red above green <sep> green top <sep> blue right of red <img>',
162                 'there is red <sep> there is yellow <sep> there is blue <img>',
163                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
164                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
165         ]:
166
167             for k in range(nb_per_primer):
168                 t_primer = primer.strip().split(' ')
169                 t_generated = [ ]
170
171                 for j in range(nb_tokens):
172                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
173                     input = torch.tensor(t, device = self.device)
174                     output = model(input)
175                     logits = output[0, -1]
176                     if args.synthesis_sampling:
177                         dist = torch.distributions.categorical.Categorical(logits = logits)
178                         t = dist.sample()
179                     else:
180                         t = logits.argmax()
181                     t_generated.append(self.id2token[t.item()])
182
183                 descr = [ ' '.join(t_primer + t_generated) ]
184                 img += [ picoclvr.descr2img(descr) ]
185
186         img = torch.cat(img, 0)
187         file_name = f'result_picoclvr_{n_epoch:04d}.png'
188         torchvision.utils.save_image(img / 255.,
189                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
190         log_string(f'wrote {file_name}')
191
192 ######################################################################
193
194 class TaskWiki103(Task):
195
196     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
197                  device = torch.device('cpu')):
198
199         self.batch_size = batch_size
200         self.len_min = len_min
201         self.len_max = len_max
202         self.min_freq = min_freq
203         self.device = device
204
205         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
206         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
207
208         # Mostly for debug
209         if args.data_size > 0:
210             train_iter = itertools.islice(train_iter, args.data_size)
211
212         def yield_tokens():
213             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
214                 yield self.tokenizer(l)
215
216         self.vocab = torchtext.vocab.build_vocab_from_iterator(
217             yield_tokens(),
218             specials = [ '<unk>', '<non>' ],
219             min_freq = self.min_freq
220         )
221
222         self.vocab.set_default_index(self.vocab[ '<unk>' ])
223
224     def tensorize(self, s):
225         a = max(len(x) for x in s)
226         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
227
228     def yield_batches(self, ds):
229         s = [ ]
230         for l in ds:
231             q = self.tokenizer(l)
232             if len(q) >= self.len_min and len(q) <= self.len_max:
233                 s += [ q ]
234                 if len(s) == self.batch_size:
235                     yield self.tensorize(s)
236                     s = [ ]
237
238         if len(s) > 0:
239             yield self.tensorize(s)
240
241     def batches(self, split = 'train'):
242         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
243
244         # Mostly for debug
245         if args.data_size > 0:
246             data_iter = itertools.islice(data_iter, args.data_size)
247
248         return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
249
250     def vocabulary_size(self):
251         return len(self.vocab)
252
253     def produce_results(self, n_epoch, model, nb_tokens = 50):
254         file_name = f'result_wiki103_{n_epoch:04d}.txt'
255
256         with open(file_name, 'w') as outfile:
257              for primer in [
258                      'the cat is hunting a',
259                      'paris is the capital',
260                      'cars are convenient',
261                      'the difference between men and women is',
262                      'the object was blue all over and green all over it was',
263                      'cherries are red and lemons are',
264                      'cherries are sweet and lemons are',
265                      'two plus three equals',
266                      'deep learning is',
267              ]:
268                  t_primer = self.tokenizer(primer)
269                  t_generated = [ ]
270
271                  for j in range(nb_tokens):
272
273                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
274                      output = model(input)
275                      logits = output[0, -1]
276                      if args.synthesis_sampling:
277                          dist = torch.distributions.categorical.Categorical(logits = logits)
278                          t = dist.sample()
279                      else:
280                          t = logits.argmax()
281                      t_generated.append(self.vocab.lookup_token(t))
282                      if t_generated[-1] == '<non>': break
283
284                  s = ' '.join(t_generated)
285
286                  outfile.write(f'<{primer}> {s}\n')
287
288         log_string(f'wrote {file_name}')
289
290 ######################################################################
291
292 class TaskMNIST(Task):
293
294     def __init__(self, batch_size, device = torch.device('cpu')):
295         self.device = device
296         self.batch_size = batch_size
297
298     def batches(self, split = 'train'):
299         assert split in { 'train', 'test' }
300         data_set = torchvision.datasets.MNIST(
301             root = './data', train = (split == 'train'),
302             download = True
303         )
304         data_input = data_set.data.view(-1, 28 * 28).long()
305         if args.data_size >= 0:
306             data_input = data_input[:args.data_size]
307         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
308             yield batch
309
310     def vocabulary_size(self):
311         return 256
312
313     def produce_results(self, n_epoch, model, nb_samples = 64):
314         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
315         for input in results.split(self.batch_size):
316             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
317                 output = model(input)
318                 logits = output[:, s]
319                 if args.synthesis_sampling:
320                     dist = torch.distributions.categorical.Categorical(logits = logits)
321                     t = dist.sample()
322                 else:
323                     t = logits.argmax(1)
324                 input[:, s + 1] = t
325
326         image_name = f'result_mnist_{n_epoch:04d}.png'
327         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
328                                      image_name, nrow = 16, pad_value = 0.8)
329         log_string(f'wrote {image_name}')
330
331 ######################################################################
332
333 def check_causality(model):
334     #m = model[1:]
335     input = torch.rand(1, 5, dim_model).requires_grad_()
336     output = m(input)
337     a = torch.zeros(output.size(1), input.size(1))
338     for k in range(output.size(1)):
339         for d in range(output.size(2)):
340             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
341             a[k] += g.squeeze(0).pow(2).sum(1)
342     print(a)
343
344 ######################################################################
345
346 log_string(f'device {device}')
347
348 if args.data == 'wiki103':
349     task = TaskWiki103(batch_size = args.batch_size, device = device)
350 elif args.data == 'mnist':
351     task = TaskMNIST(batch_size = args.batch_size, device = device)
352 elif args.data == 'picoclvr':
353     task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
354 else:
355     raise ValueError(f'Unknown dataset {args.data}.')
356
357 vocabulary_size = task.vocabulary_size()
358
359 log_string(f'vocabulary_size {vocabulary_size}')
360
361 ##############################
362
363 model = mygpt.MyGPT(
364     vocabulary_size = vocabulary_size,
365     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
366     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
367 )
368
369 nb_parameters = sum(p.numel() for p in model.parameters())
370 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
371
372 model.to(device)
373
374 ######################################################################
375
376 if args.optim == 'sgd':
377     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
378 elif args.optim == 'adam':
379     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
380 elif args.optim == 'adamw':
381     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
382 else:
383     raise ValueError(f'Unknown optimizer {args.optim}.')
384
385 for k in range(args.nb_epochs):
386
387     model.train()
388
389     nb_train_samples, acc_train_loss = 0, 0.0
390
391     for input in task.batches(split = 'train'):
392         input = input.to(device)
393         output = model(input)
394         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
395         acc_train_loss += loss.item() * input.size(0)
396         nb_train_samples += input.size(0)
397
398         optimizer.zero_grad()
399         loss.backward()
400         optimizer.step()
401
402     with torch.autograd.no_grad():
403
404         model.eval()
405
406         nb_test_samples, acc_test_loss = 0, 0.0
407
408         for input in task.batches(split = 'test'):
409             input = input.to(device)
410             output = model(input)
411             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
412             acc_test_loss += loss.item() * input.size(0)
413             nb_test_samples += input.size(0)
414
415         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
416         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
417
418         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
419
420         task.produce_results(k, model)
421
422 ######################################################################