3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
20 ######################################################################
# Command-line interface. Defaults reproduce the original configuration.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / reproducibility
parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

# Training schedule
parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

# Dataset selection: 'wiki103', 'mnist' or 'picoclvr'
parser.add_argument('--data',
                    type = str, default = 'wiki103')

# Truncate the dataset to this many samples (<= 0 keeps everything)
parser.add_argument('--data_size',
                    type = int, default = -1)

# Optimizer: 'sgd', 'adam' or 'adamw'
parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Model geometry
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUGFIX: the original declared action='store_true' with default=True, so
# the flag was a no-op and sampling could never be switched off from the
# command line. Keep the (now redundant) positive flag for backward
# compatibility and add an explicit negation flag writing the same dest.
parser.add_argument('--synthesis_sampling',
                    dest = 'synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling',
                    action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr-task options

parser.add_argument('--picoclvr_many_colors',
                    action='store_true', default = False)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)
90 ######################################################################
# Parse the command line and set up run-wide state.
args = parser.parse_args()

# Opened in 'w' mode: each run truncates the previous log file.
log_file = open(args.log_filename, 'w')

# Seed the PRNG so runs are reproducible (default seed 0).
torch.manual_seed(args.seed)
99 ######################################################################
# NOTE(review): the `def log_string(s):` header enclosing the first three
# lines below, and the `for n in vars(args):` loop header enclosing the
# last line, appear to have been lost in this excerpt -- restore them.
# Prefix every log line with a local timestamp.
t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
if log_file is not None:
    log_file.write(t + s + '\n')
# Log every parsed argument at startup for reproducibility.
log_string(f'args.{n} {getattr(args, n)}')
114 ######################################################################
# NOTE(review): the `class Task:` header and the method bodies (presumably
# `pass`) of this abstract task interface are missing from this excerpt.

# Yield mini-batches for the requested split ('train' or 'test').
def batches(self, split = 'train'):

# Return the number of distinct token ids the task uses.
def vocabulary_size(self):

# Produce and save/log sample outputs after epoch n_epoch.
def produce_results(self, n_epoch, model, nb_tokens = 50):
126 ######################################################################
class TaskPicoCLVR(Task):
    """Task over the picoclvr synthetic scene-description data.

    Samples are space-separated token sequences; shorter descriptions are
    right-padded with '<unk>' so a split forms one rectangular tensor.
    NOTE(review): several lines of this class were lost in this excerpt
    (initializations, loop headers, closing parens); the hedged comments
    below mark each visible gap.
    """

    def __init__(self, batch_size,
                 height, width, many_colors = False,
                 device = torch.device('cpu')):

        # Generate nb descriptions, tokenize, and pad to equal length.
        def generate_descr(nb):
            descr = picoclvr.generate(
                height = self.height, width = self.width,
                many_colors = many_colors
            # NOTE(review): the `nb` argument and closing paren of the call
            # above are missing from this excerpt -- confirm upstream.
            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
            # NOTE(review): a `return descr` presumably followed here.

        self.batch_size = batch_size
        # NOTE(review): self.height / self.width / self.device assignments
        # are not visible in this excerpt but are read below -- confirm.

        # Default corpus size when --data_size is not positive.
        nb = args.data_size if args.data_size > 0 else 250000

        # 80% / 20% train / test split of the generated descriptions.
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        for d in [ self.train_descr, self.test_descr ]:
                for t in s: tokens.add(t)
        # NOTE(review): the `tokens = set()` initialization and the
        # `for s in d:` loop header are missing from this excerpt.
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Encode both splits as fixed-shape integer tensors.
        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
        self.train_input = torch.tensor(t, device = self.device)
        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
        self.test_input = torch.tensor(t, device = self.device)

    def batches(self, split = 'train'):
        # Yield mini-batches of the requested split with a progress bar.
        assert split in { 'train', 'test' }
        for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
        # NOTE(review): the `yield batch` bodies and the if/else selecting
        # between the two loops are missing from this excerpt.
        for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):
        # One id per distinct token seen while building the tokenizer.
        return len(self.token2id)

    def generate(self, primer, model, nb_tokens):
        # Autoregressively extend `primer` by nb_tokens tokens, re-running
        # the model on the whole prefix at every step.
        t_primer = primer.strip().split(' ')
        # NOTE(review): a `t_generated = []` initialization is presumably
        # missing here.

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            output = model(input)
            logits = output[0, -1]  # next-token logits after the prefix
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
            # NOTE(review): the `t = dist.sample()` call and the greedy
            # argmax else-branch are missing from this excerpt.
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = None):
        # Synthesize scenes from fixed primers, save one image grid, and
        # log the average number of unsatisfied requested properties.
        if nb_tokens is None:
            nb_tokens = self.height * self.width + 3

        # NOTE(review): the list literal holding these primers, plus the
        # descr / nb_per_primer initializations and the loop over primers,
        # are missing from this excerpt.
        'red above green <sep> green top <sep> blue right of red <img>',
        'there is red <sep> there is yellow <sep> there is blue <img>',
        'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
        'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',

        for k in range(nb_per_primer):
            descr.append(self.generate(primer, model, nb_tokens))

        # Render each generated description back into an image.
        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            image_name, nrow = nb_per_primer, pad_value = 0.8
        # NOTE(review): the image-tensor argument and closing paren of the
        # save_image call above are missing from this excerpt.
        log_string(f'wrote {image_name}')

        # Count requested properties each rendered image fails to satisfy.
            x[2] for x in picoclvr.nb_missing_properties(
                height = self.height, width = self.width
        # NOTE(review): the enclosing sum/comprehension and remaining
        # arguments around the expression above are missing.
        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
235 ######################################################################
class TaskWiki103(Task):
    """Word-level language modelling on WikiText-103 via torchtext.

    NOTE(review): some lines were lost in this excerpt (a generator
    header, list/accumulator initializations); see the hedged comments
    below.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min      # keep only lines with >= len_min tokens
        self.len_max = len_max      # ... and <= len_max tokens
        self.min_freq = min_freq    # vocabulary frequency cutoff
        # NOTE(review): a `self.device = device` assignment is not visible
        # here, but self.device is read in produce_results -- confirm.

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally restrict the vocabulary pass to the first lines.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        # NOTE(review): the generator-function header (e.g.
        # `def yield_tokens():`) enclosing the two lines below is missing
        # from this excerpt.
        for l in tqdm.tqdm(train_iter, desc = 'vocab'):
            yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        # NOTE(review): the iterator argument and closing paren of the
        # build_vocab_from_iterator call are missing from this excerpt.

        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        # Encode a list of token lists as one integer tensor, right-padding
        # every row with '<non>' to the longest row's length.
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        # Group tokenized lines of acceptable length into batch_size-sized
        # batches, yielding a final partial batch at the end.
        # NOTE(review): the accumulator init (e.g. `s = []`), the loop
        # header over ds, and the append/reset statements are missing from
        # this excerpt.
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                if len(s) == self.batch_size:
                    yield self.tensorize(s)

            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Same truncation convention as the vocabulary pass.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Continue a few fixed text primers with the model and write the
        # completions to a per-epoch result file.
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            # NOTE(review): the list literal holding these primers, the
            # loop over it, and a `t_generated = []` initialization are
            # missing from this excerpt.
            'the cat is hunting a',
            'paris is the capital',
            'cars are convenient',
            'the difference between men and women is',
            'the object was blue all over and green all over it was',
            'cherries are red and lemons are',
            'cherries are sweet and lemons are',
            'two plus three equals',

            t_primer = self.tokenizer(primer)

            for j in range(nb_tokens):
                input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                output = model(input)
                logits = output[0, -1]  # next-token logits
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                # NOTE(review): the `t = dist.sample()` call and the greedy
                # argmax else-branch are missing from this excerpt.
                t_generated.append(self.vocab.lookup_token(t))
                if t_generated[-1] == '<non>': break  # padding token ends synthesis

            s = ' '.join(t_generated)

            outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
333 ######################################################################
class TaskMNIST(Task):
    """Autoregressive modelling of MNIST digits as 784-pixel sequences."""

    def __init__(self, batch_size, device = torch.device('cpu')):

        self.batch_size = batch_size
        # NOTE(review): a `self.device = device` assignment is not visible
        # in this excerpt, but self.device is read in produce_results.

    def batches(self, split = 'train'):
        # Yield batches of flattened 28x28 images as int64 pixel sequences.
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
        # NOTE(review): the remaining arguments (presumably download) and
        # closing paren of the MNIST(...) call are missing from this
        # excerpt.
        data_input = data_set.data.view(-1, 28 * 28).long()
        # NOTE(review): the other tasks test `args.data_size > 0`; here the
        # test is `>= 0`, so `--data_size 0` yields an empty dataset --
        # confirm whether that asymmetry is intended.
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
        # NOTE(review): the `yield batch` body is missing from this excerpt.

    def vocabulary_size(self):
        # NOTE(review): the return statement (presumably one id per pixel
        # intensity) is missing from this excerpt.

    def produce_results(self, n_epoch, model, nb_samples = 64):
        # Sample nb_samples images pixel by pixel and save them as a grid.
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]  # logits conditioning pixel s + 1
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                # NOTE(review): the sample / argmax branches and the write
                # back into the next pixel position are missing from this
                # excerpt.

        image_name = f'result_mnist_{n_epoch:04d}.png'
        # 1 - x/255 inverts intensities so digits come out dark on light.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
374 ######################################################################
def check_causality(model):
    # Sanity check that output position k depends only on input positions
    # <= k: accumulate squared input-gradients of every output component
    # into a (T, T) attribution matrix, which should come out triangular.
    # NOTE(review): the `output = model(input)` call and the final
    # check/return are missing from this excerpt; `dim_model` is read from
    # an enclosing scope not visible here -- confirm upstream.
    input = torch.rand(1, 5, dim_model).requires_grad_()
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            # retain_graph: the same graph is re-differentiated once per
            # output component.
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
387 ######################################################################
log_string(f'device {device}')

# Instantiate the task selected with --data.
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        many_colors = args.picoclvr_many_colors,
    # NOTE(review): the closing argument(s) of the TaskPicoCLVR(...) call
    # (presumably `device = device)`) and the `else:` guarding the raise
    # below are missing from this excerpt.
    raise ValueError(f'Unknown dataset {args.data}.')

# The model's output dimension follows from the task's vocabulary.
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): the constructor header this keyword block belongs to
# (presumably `model = mygpt.MyGPT(`) and its closing paren / move to
# `device` are missing from this excerpt.
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

# Report model size both exactly and in millions of parameters.
nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
421 ######################################################################
# Select the optimizer requested with --optim. All three constructors
# share the (params, lr=...) signature, so a dispatch table replaces the
# if/elif chain. (In the reviewed excerpt the final `raise` had lost its
# guarding `else:`; this form restores the intended behavior: the error
# is raised only for an unknown name.)
_OPTIMIZERS = {
    'sgd':   torch.optim.SGD,
    'adam':  torch.optim.Adam,
    'adamw': torch.optim.AdamW,
}

if args.optim not in _OPTIMIZERS:
    raise ValueError(f'Unknown optimizer {args.optim}.')

optimizer = _OPTIMIZERS[args.optim](model.parameters(), lr = args.learning_rate)
432 ######################################################################
# Resume from a checkpoint unless --no_checkpoint was given.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f'Not trying to load checkpoint.')
# NOTE(review): the `else:` / `try:` lines introducing the block below are
# missing from this excerpt.
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

    # A missing checkpoint file is the normal first-run case.
    except FileNotFoundError:
        log_string('Starting from scratch.')

    # NOTE(review): the except clause (and any exit) guarding this handler
    # body is missing from this excerpt.
        log_string('Error when loading the checkpoint.')
454 ######################################################################
# Main loop: one training pass + one evaluation pass per epoch, resuming
# from nb_epochs_finished.
for k in range(nb_epochs_finished, args.nb_epochs):
    # NOTE(review): a `model.train()` call is presumably missing here.

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: logits at position t are scored against
        # the token at t+1; cross_entropy wants (N, C, T), hence transpose.
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        # NOTE(review): `loss.backward()` and `optimizer.step()` are
        # missing from this excerpt -- without them no parameter update
        # happens; restore from upstream.

    with torch.autograd.no_grad():
        # NOTE(review): a `model.eval()` call is presumably missing here.

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # min(100, ...) clamps the mean loss so exp() cannot overflow.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    # Persist progress so an interrupted run restarts at epoch k + 1.
    # NOTE(review): the `checkpoint = {` opening and `}` closing of this
    # dict literal are missing from this excerpt.
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()

    torch.save(checkpoint, args.checkpoint_name)
501 ######################################################################