# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

# Assumed companion module of this script, providing the MyGPT model class
# instantiated below.
import mygpt
######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

parser = argparse.ArgumentParser(description = 'My own GPT.')
parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)
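# Note: with default = True this store_true flag is effectively always on,
# so synthesis samples from the predicted distribution; the greedy (argmax)
# branches below are only reachable by editing this default.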
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
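# Options specific to the picoclvr task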
parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
######################################################################

def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
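# Generic autoregressive synthesis: fill a (nb_samples, nb_tokens) tensor
# token by token, running the model on the current prefix and either sampling
# the next token from the logits or taking their argmax, depending on
# --synthesis_sampling. An optional starting_input is used as a primer.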
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, starting_input = None,
        device = torch.device('cpu')
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    first = 0
    if starting_input is not None:
        first = starting_input.size(1)
        results = torch.cat((starting_input, results), 1)

    for input in results.split(batch_size):
        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results

######################################################################
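# Minimal task interface: a task provides token batches for both splits, its
# vocabulary size, and a produce_results() hook called after every epoch.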
class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        pass

######################################################################
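# The picoclvr module is assumed to be the companion picoclvr.py of this
# repository, providing generate(), descr2img() and nb_missing_properties().
import picoclvr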
class TaskPicoCLVR(Task):

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            descr = picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

            return descr

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)
        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
        self.train_input = torch.tensor(t, device = self.device)
        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
        self.test_input = torch.tensor(t, device = self.device)
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
    def vocabulary_size(self):
        return len(self.token2id)
    def generate(self, primer, model, nb_tokens):
        t_primer = primer.strip().split(' ')
        t_generated = [ ]

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
            output = model(input)
            logits = output[0, -1]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax()
            t_generated.append(self.id2token[t_next.item()])

        return ' '.join(t_primer + t_generated)
    def produce_results(self, n_epoch, model, nb_tokens = None):
        if nb_tokens is None:
            nb_tokens = self.height * self.width + 3
        descr = [ ]
        nb_per_primer = 8 # number of samples generated per primer (value assumed)

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                descr.append(self.generate(primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255., # assuming descr2img returns pixel values in [0, 255]
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

        nb_missing = sum(
            x[2] for x in picoclvr.nb_missing_properties(
                descr,
                height = self.height, width = self.width
            )
        )

        log_string(f'nb_missing {nb_missing / len(descr):.02f}')

######################################################################
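# Word-level language modeling on WikiText-103 (downloaded through torchtext).
# Lines are tokenized with the basic_english tokenizer and padded with <non>
# so that they can be stacked into fixed-size batches.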
class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)
    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
    def produce_results(self, n_epoch, model, nb_tokens = 50):
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next.item()))
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')

######################################################################
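# MNIST digits seen as sequences of 28 * 28 pixel intensities, one token per
# possible value (0..255); synthesis generates images pixel by pixel.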
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch
    def vocabulary_size(self):
        return 256 # one token per pixel intensity

    def produce_results(self, n_epoch, model, nb_samples = 64):
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
######################################################################

log_string(f'device {device}')
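# Pick the task from --data; each dataset comes with its own default number
# of epochs, e.g. run as: python main.py --data mnist  (script name assumed)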
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################

# The model class is assumed to be MyGPT from the companion mygpt.py module.
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################

if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('Not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('Starting from scratch.')

    except:
        log_string('Error when loading the checkpoint.')
        exit(1)

######################################################################
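# Training setup: number of epochs, plus a reference perplexity computed as
# the exponential of the entropy of the empirical token distribution of the
# training set, i.e. the perplexity of the best memoryless (unigram) model.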
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
h = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(h)
log_string(f'Train set perplexity {train_set_perplexity}')
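# One epoch = one pass over task.batches('train') with next-token
# cross-entropy (the model outputs logits of shape (N, T, V), transposed to
# (N, V, T) for F.cross_entropy), followed by an evaluation pass, result
# generation, and a checkpoint save.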
for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)

######################################################################