Update.
[mygpt.git] / main.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, argparse, time, tqdm, itertools
9
10 import torch, torchtext, torchvision
11 from torch import nn
12 from torch.nn import functional as F
13
14 import mygpt
15
16 ######################################################################
17
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20 ######################################################################
21
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
23
24 parser.add_argument('--log_filename',
25                     type = str, default = 'train.log')
26
27 parser.add_argument('--download',
28                     type = bool, default = False)
29
30 parser.add_argument('--seed',
31                     type = int, default = 0)
32
33 parser.add_argument('--nb_epochs',
34                     type = int, default = 100)
35
36 parser.add_argument('--batch_size',
37                     type = int, default = 25)
38
39 parser.add_argument('--data',
40                     type = str, default = 'wiki103')
41
42 parser.add_argument('--data_size',
43                     type = int, default = -1)
44
45 parser.add_argument('--optim',
46                     type = str, default = 'adam')
47
48 parser.add_argument('--learning_rate',
49                     type = float, default = 1e-4)
50
51 parser.add_argument('--dim_model',
52                     type = int, default = 512)
53
54 parser.add_argument('--dim_keys',
55                     type = int, default = 64)
56
57 parser.add_argument('--dim_hidden',
58                     type = int, default = 2048)
59
60 parser.add_argument('--nb_heads',
61                     type = int, default = 8)
62
63 parser.add_argument('--nb_blocks',
64                     type = int, default = 12)
65
66 parser.add_argument('--dropout',
67                     type = float, default = 0.1)
68
69 parser.add_argument('--synthesis_sampling',
70                     type = bool, default = True)
71
72 ######################################################################
73
74 args = parser.parse_args()
75
76 log_file = open(args.log_filename, 'w')
77
78 if args.seed >= 0:
79     torch.manual_seed(args.seed)
80
81 ######################################################################
82
83 def log_string(s):
84     t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
85
86     if log_file is not None:
87         log_file.write(t + s + '\n')
88         log_file.flush()
89
90     print(t + s)
91     sys.stdout.flush()
92
93 for n in vars(args):
94     log_string(f'args.{n} {getattr(args, n)}')
95
96 ######################################################################
97
98 class Task:
99     def batches(self, split = 'train'):
100         pass
101
102     def vocabulary_size(self):
103         pass
104
105     def produce_results(self, n_epoch, model, nb_tokens = 50):
106         pass
107
108 ######################################################################
109
110 import picoclvr
111
112 class TaskPicoCLVR(Task):
113
114     def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
115         self.batch_size = batch_size
116         self.device = device
117         nb = args.data_size if args.data_size > 0 else 250000
118
119         descr = picoclvr.generate(nb, height = height, width = width)
120         descr = [ s.strip().split(' ') for s in descr ]
121         l = max([ len(s) for s in descr ])
122         descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
123
124         tokens = set()
125         for s in descr:
126             for t in s: tokens.add(t)
127         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
128         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
129
130         t = [ [ self.token2id[u] for u in s ] for s in descr ]
131         data_input = torch.tensor(t, device = self.device)
132
133         self.test_input = data_input[:nb // 5]
134         self.train_input = data_input[nb // 5:]
135
136     def batches(self, split = 'train'):
137         assert split in { 'train', 'test' }
138         if split == 'train':
139             for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
140                 yield batch
141         else:
142             for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
143                 yield batch
144
145     def vocabulary_size(self):
146         return len(self.token2id)
147
148     def produce_results(self, n_epoch, model, nb_tokens = 50):
149         img = [ ]
150         nb_per_primer = 8
151
152         for primer in [
153                 'red above green <sep> green top <sep> blue right of red <img>',
154                 'there is red <sep> there is yellow <sep> there is blue <img>',
155                 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
156                 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
157         ]:
158
159             for k in range(nb_per_primer):
160                 t_primer = primer.strip().split(' ')
161                 t_generated = [ ]
162
163                 for j in range(nb_tokens):
164                     t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
165                     input = torch.tensor(t, device = self.device)
166                     output = model(input)
167                     logits = output[0, -1]
168                     if args.synthesis_sampling:
169                         dist = torch.distributions.categorical.Categorical(logits = logits)
170                         t = dist.sample()
171                     else:
172                         t = logits.argmax()
173                     t_generated.append(self.id2token[t.item()])
174
175                 descr = [ ' '.join(t_primer + t_generated) ]
176                 img += [ picoclvr.descr2img(descr) ]
177
178         img = torch.cat(img, 0)
179         file_name = f'result_picoclvr_{n_epoch:04d}.png'
180         torchvision.utils.save_image(img / 255.,
181                                      file_name, nrow = nb_per_primer, pad_value = 0.8)
182         log_string(f'wrote {file_name}')
183
184 ######################################################################
185
186 class TaskWiki103(Task):
187
188     def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
189                  device = torch.device('cpu')):
190
191         self.batch_size = batch_size
192         self.len_min = len_min
193         self.len_max = len_max
194         self.min_freq = min_freq
195         self.device = device
196
197         self.tokenizer = torchtext.data.get_tokenizer('basic_english')
198         train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
199
200         # Mostly for debug
201         if args.data_size > 0:
202             train_iter = itertools.islice(train_iter, args.data_size)
203
204         def yield_tokens():
205             for l in tqdm.tqdm(train_iter, desc = 'vocab'):
206                 yield self.tokenizer(l)
207
208         self.vocab = torchtext.vocab.build_vocab_from_iterator(
209             yield_tokens(),
210             specials = [ '<unk>', '<non>' ],
211             min_freq = self.min_freq
212         )
213
214         self.vocab.set_default_index(self.vocab[ '<unk>' ])
215
216     def tensorize(self, s):
217         a = max(len(x) for x in s)
218         return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
219
220     def yield_batches(self, ds):
221         s = [ ]
222         for l in ds:
223             q = self.tokenizer(l)
224             if len(q) >= self.len_min and len(q) <= self.len_max:
225                 s += [ q ]
226                 if len(s) == self.batch_size:
227                     yield self.tensorize(s)
228                     s = [ ]
229
230         if len(s) > 0:
231             yield self.tensorize(s)
232
233     def batches(self, split = 'train'):
234         data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
235
236         # Mostly for debug
237         if args.data_size > 0:
238             data_iter = itertools.islice(data_iter, args.data_size)
239
240         return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
241
242     def vocabulary_size(self):
243         return len(self.vocab)
244
245     def produce_results(self, n_epoch, model, nb_tokens = 50):
246         file_name = f'result_wiki103_{n_epoch:04d}.txt'
247
248         with open(file_name, 'w') as outfile:
249              for primer in [
250                      'the cat is hunting a',
251                      'paris is the capital',
252                      'cars are convenient',
253                      'the difference between men and women is',
254                      'the object was blue all over and green all over it was',
255                      'cherries are red and lemons are',
256                      'cherries are sweet and lemons are',
257                      'two plus three equals',
258                      'deep learning is',
259              ]:
260                  t_primer = self.tokenizer(primer)
261                  t_generated = [ ]
262
263                  for j in range(nb_tokens):
264
265                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
266                      output = model(input)
267                      logits = output[0, -1]
268                      if args.synthesis_sampling:
269                          dist = torch.distributions.categorical.Categorical(logits = logits)
270                          t = dist.sample()
271                      else:
272                          t = logits.argmax()
273                      t_generated.append(self.vocab.lookup_token(t))
274                      if t_generated[-1] == '<non>': break
275
276                  s = ' '.join(t_generated)
277
278                  outfile.write(f'<{primer}> {s}\n')
279
280         log_string(f'wrote {file_name}')
281
282 ######################################################################
283
284 class TaskMNIST(Task):
285
286     def __init__(self, batch_size, device = torch.device('cpu')):
287         self.device = device
288         self.batch_size = batch_size
289
290     def batches(self, split = 'train'):
291         assert split in { 'train', 'test' }
292         data_set = torchvision.datasets.MNIST(
293             root = './data', train = (split == 'train'),
294             download = True
295         )
296         data_input = data_set.data.view(-1, 28 * 28).long()
297         if args.data_size >= 0:
298             data_input = data_input[:args.data_size]
299         for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
300             yield batch
301
302     def vocabulary_size(self):
303         return 256
304
305     def produce_results(self, n_epoch, model, nb_samples = 64):
306         results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
307         for input in results.split(self.batch_size):
308             for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
309                 output = model(input)
310                 logits = output[:, s]
311                 if args.synthesis_sampling:
312                     dist = torch.distributions.categorical.Categorical(logits = logits)
313                     t = dist.sample()
314                 else:
315                     t = logits.argmax(1)
316                 input[:, s + 1] = t
317
318         image_name = f'result_mnist_{n_epoch:04d}.png'
319         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
320                                      image_name, nrow = 16, pad_value = 0.8)
321         log_string(f'wrote {image_name}')
322
323 ######################################################################
324
325 def check_causality(model):
326     #m = model[1:]
327     input = torch.rand(1, 5, dim_model).requires_grad_()
328     output = m(input)
329     a = torch.zeros(output.size(1), input.size(1))
330     for k in range(output.size(1)):
331         for d in range(output.size(2)):
332             g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
333             a[k] += g.squeeze(0).pow(2).sum(1)
334     print(a)
335
336 ######################################################################
337
338 log_string(f'device {device}')
339
340 if args.data == 'wiki103':
341     task = TaskWiki103(batch_size = args.batch_size, device = device)
342 elif args.data == 'mnist':
343     task = TaskMNIST(batch_size = args.batch_size, device = device)
344 elif args.data == 'picoclvr':
345     task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
346 else:
347     raise ValueError(f'Unknown dataset {args.data}.')
348
349 vocabulary_size = task.vocabulary_size()
350
351 log_string(f'vocabulary_size {vocabulary_size}')
352
353 ##############################
354
355 model = mygpt.MyGPT(
356     vocabulary_size = vocabulary_size,
357     dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
358     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
359 )
360
361 nb_parameters = sum(p.numel() for p in model.parameters())
362 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
363
364 model.to(device)
365
366 ######################################################################
367
368 if args.optim == 'sgd':
369     optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
370 elif args.optim == 'adam':
371     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
372 elif args.optim == 'adamw':
373     optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
374 else:
375     raise ValueError(f'Unknown optimizer {args.optim}.')
376
377 for k in range(args.nb_epochs):
378
379     model.train()
380
381     nb_train_samples, acc_train_loss = 0, 0.0
382
383     for input in task.batches(split = 'train'):
384         input = input.to(device)
385         output = model(input)
386         loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
387         acc_train_loss += loss.item() * input.size(0)
388         nb_train_samples += input.size(0)
389
390         optimizer.zero_grad()
391         loss.backward()
392         optimizer.step()
393
394     with torch.autograd.no_grad():
395
396         model.eval()
397
398         nb_test_samples, acc_test_loss = 0, 0.0
399
400         for input in task.batches(split = 'test'):
401             input = input.to(device)
402             output = model(input)
403             loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
404             acc_test_loss += loss.item() * input.size(0)
405             nb_test_samples += input.size(0)
406
407         train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
408         test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
409
410         log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
411
412         task.produce_results(k, model)
413
414 ######################################################################