From 943a440a83b98de60bad767a9ad09f63b5088514 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 17 Dec 2022 12:41:39 +0100 Subject: [PATCH] Initial commit. --- main.py | 630 +++++++++++++++++++++++++++++++++++++++++++++++++ mygpt.py | 292 +++++++++++++++++++++++ picoclvr.py | 511 +++++++++++++++++++++++++++++++++++++++ tensorstack.py | 62 +++++ 4 files changed, 1495 insertions(+) create mode 100755 main.py create mode 100755 mygpt.py create mode 100755 picoclvr.py create mode 100755 tensorstack.py diff --git a/main.py b/main.py new file mode 100755 index 0000000..6d9f69d --- /dev/null +++ b/main.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, argparse, time, tqdm, itertools, os + +import torch, torchvision +from torch import nn +from torch.nn import functional as F + +import mygpt, tensorstack + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +parser = argparse.ArgumentParser( + description="An implementation of GPT with cache to solve a toy geometric reasonning task." +) + +parser.add_argument("--log_filename", type=str, default="train.log") + +parser.add_argument("--result_dir", type=str, default="results_default") + +parser.add_argument("--seed", type=int, default=0) + +parser.add_argument("--nb_epochs", type=int, default=25) + +parser.add_argument("--batch_size", type=int, default=100) + +parser.add_argument("--data_size", type=int, default=-1) + +parser.add_argument("--optim", type=str, default="adam") + +parser.add_argument("--learning_rate", type=float, default=1e-3) + +parser.add_argument( + "--learning_rate_schedule", type=str, default="10: 2e-4,20: 4e-5,30: 8e-6" +) + +parser.add_argument("--dim_model", type=int, default=512) + +parser.add_argument("--dim_keys", type=int, default=64) + +parser.add_argument("--dim_hidden", type=int, default=2048) + +parser.add_argument("--nb_heads", type=int, default=8) + +parser.add_argument("--nb_blocks", type=int, default=12) + +parser.add_argument("--dropout", type=float, default=0.1) + +parser.add_argument("--nb_oneshot_blocks", type=int, default=-1) + +parser.add_argument("--deterministic_synthesis", action="store_true", default=False) + +parser.add_argument("--no_checkpoint", action="store_true", default=False) + +parser.add_argument("--overwrite_results", action="store_true", default=False) + +parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") + +############################## +# picoclvr options + +parser.add_argument("--nb_colors", type=int, default=5) + +parser.add_argument("--height", type=int, default=12) + +parser.add_argument("--width", type=int, default=16) + +parser.add_argument("--prune_properties", type=str, default="none") + +###################################################################### + +args = parser.parse_args() + +assert args.prune_properties in {"none", "train+eval", "eval"} + +try: + os.mkdir(args.result_dir) +except FileExistsError: + if not args.overwrite_results: + print(f"result directory {args.result_dir} already exists") + exit(1) + +log_file = open(os.path.join(args.result_dir, args.log_filename), "w") + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +for n in vars(args): + log_string(f"args.{n} {getattr(args, n)}") + +###################################################################### + + +def masked_inplace_autoregression( + model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu") +): + + for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): + i = (ar_mask.sum(0) > 0).nonzero() + if i.min() > 0: + model( + mygpt.BracketedSequence(input, 0, i.min()) + ) # Needed to initialize the model's cache + for s in range(i.min(), i.max() + 1): + output = model(mygpt.BracketedSequence(input, s, 1)).x + logits = output[:, s] + if forbidden_tokens is not None: + logits = logits.masked_fill(forbidden_tokens, float("-inf")) + if args.deterministic_synthesis: + t_next = logits.argmax(1) + else: + dist = torch.distributions.categorical.Categorical(logits=logits) + t_next = dist.sample() + input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + + +###################################################################### + + +class Task: + def batches(self, split="train"): + pass + + def vocabulary_size(self): + pass + + def produce_results(self, n_epoch, model): + pass + + +###################################################################### + +import picoclvr + + +class TaskPicoCLVR(Task): + + # Make a tensor from a list of strings + def tensorize(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + token_descr = [s + [""] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in token_descr] + return torch.tensor(id_descr, device=self.device) + + # Make a list of strings from a tensor + def detensorize(self, x): + return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + + # trim all the tensors in the tuple z to remove as much token from + # left and right in the first tensor. If z is a tuple, all its + # elements are trimed according to the triming for the first + def trim(self, z, token=""): + n = self.token2id[token] + if type(z) == tuple: + x = z[0] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return tuple([t[:, a:b] for t in z]) + else: + i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return z[:, a:b] + + ###################### + # Not the cleanest part of the code + + # Extract the last image of each sequence, from the last + # included, and set to all the tokens from the beginning of + # that image to the end + def excise_last_image(self, input): + t_img, t_nul = self.token2id[""], self.token2id[""] + nb_img_tokens = self.height * self.width + 1 + + input = input.clone() + t = (input == t_img).long() + tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long() + i = (t * tail_masks).nonzero(as_tuple=True) + j = ( + i[0][:, None], + i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], + ) + images = self.trim(input[j]) + input[j] = t_nul + loss_masks = 1 - tail_masks + input, loss_masks = self.trim((input, loss_masks)) + return input, loss_masks, images + + def add_true_image(self, input, images, loss_masks): + t_nul = self.token2id[""] + nb_img_tokens = self.height * self.width + 1 + input = F.pad(input, (0, nb_img_tokens), value=t_nul) + loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) + t = (input == t_nul).long() + i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) + j = ( + i[0][:, None], + i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], + ) + input[j] = images + loss_masks[j] = 1 + input, loss_masks = self.trim((input, loss_masks)) + return input, loss_masks + + def add_generated_image(self, input, loss_masks, model): + t_img, t_nul = self.token2id[""], self.token2id[""] + nb_img_tokens = self.height * self.width + 1 + + input = F.pad(input, (0, nb_img_tokens), value=t_nul) + loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) + t = (input == t_nul).long() + i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) + input[i] = t_img + + j = ( + i[0][:, None], + i[1][:, None] + + 1 + + torch.arange(nb_img_tokens - 1, device=input.device)[None, :], + ) + ar_masks = input.new_zeros(input.size(), dtype=torch.int64) + ar_masks[j] = 1 + forbidden_tokens = ( + torch.arange(self.vocabulary_size(), device=input.device) == t_nul + ) + with torch.autograd.no_grad(): + t = model.training + model.eval() + masked_inplace_autoregression( + model, + self.batch_size, + input, + ar_masks, + forbidden_tokens, + device=self.device, + ) + model.train(t) + + input, loss_masks = self.trim((input, loss_masks)) + + return input, loss_masks + + ###################### + + def __init__( + self, + batch_size, + height, + width, + nb_colors=5, + device=torch.device("cpu"), + pruner_train=None, + pruner_eval=None, + ): + def generate_descr(nb, cache_suffix, pruner): + return picoclvr.generate( + nb, + height=self.height, + width=self.width, + nb_colors=nb_colors, + pruner=pruner, + ) + + self.height = height + self.width = width + self.batch_size = batch_size + self.device = device + nb = args.data_size if args.data_size > 0 else 250000 + self.pruner_train = pruner_train + self.pruner_eval = pruner_eval + + param = { + "nb": nb, + "height": height, + "width": width, + "nb_colors": nb_colors, + "batch_size": batch_size, + "rng_state": list(torch.get_rng_state()), + } + + log_string(f"generating {nb} samples (can take some time)") + self.train_descr = generate_descr( + (nb * 4) // 5, "train", pruner=self.pruner_train + ) + self.test_descr = generate_descr((nb * 1) // 5, "test", pruner=None) + + # Build the tokenizer + tokens = {"", ""} + for d in [self.train_descr, self.test_descr]: + for s in d: + for t in s.strip().split(" "): + tokens.add(t) + # make this set a sorted list to get the same tensors given + # the same descr + tokens = list(tokens) + tokens.sort() + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + + # Tokenize the train and test sets + self.train_input = self.tensorize(self.train_descr) + self.test_input = self.tensorize(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + ): + yield self.trim(batch) + + def vocabulary_size(self): + return len(self.token2id) + + def compute_missing_properties(self, n_epoch, model, pruner=None): + + acc_nb_requested_properties = [] + acc_nb_missing_properties = [] + acc_nb_results = 0 + + for input in tqdm.tqdm( + self.test_input.split(self.batch_size), + dynamic_ncols=True, + desc=f"test-properties", + ): + tape, loss_masks, _ = self.excise_last_image(input) + tape, loss_masks = self.add_generated_image(tape, loss_masks, model) + result_descr = self.detensorize(tape) + np = picoclvr.nb_properties( + result_descr, + height=self.height, + width=self.width, + pruner=pruner, + ) + nb_requested_properties, _, nb_missing_properties = zip(*np) + acc_nb_requested_properties += nb_requested_properties + acc_nb_missing_properties += nb_missing_properties + acc_nb_results += len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "" if pruner is None else "pruned_" + log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + log_string( + f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + ) + log_string( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + ###################################################################### + + def produce_results(self, n_epoch, model): + + self.compute_missing_properties(n_epoch, model) + + if self.pruner_eval is not None: + self.compute_missing_properties(n_epoch, model, self.pruner_eval) + + nb_tokens_to_generate = self.height * self.width + 3 + result_descr = [] + nb_per_primer = 8 + primer = [] + + for primer_descr in [ + "red above green green top blue right of red", + "there is red there is yellow there is blue", + "red below yellow yellow below green green below blue red right yellow left green right blue left", + "green bottom yellow bottom green left of blue yellow right of blue blue top", + ]: + primer += [primer_descr] * nb_per_primer + + tape = self.tensorize(primer) + loss_masks = 1 - (tape == self.token2id[""]).long() + tape, loss_masks = self.add_generated_image(tape, loss_masks, model) + result_descr = self.detensorize(tape) + + np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) + + acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np) + acc_nb_results = len(result_descr) + + nb_requested_properties = sum(acc_nb_requested_properties) + nb_missing_properties = sum(acc_nb_missing_properties) + + prefix = "demo_" + log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + log_string( + f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + ) + log_string( + f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + ) + + img = picoclvr.descr2img( + result_descr, [0], height=self.height, width=self.width + ) + + if img.dim() == 5: + if img.size(1) == 1: + img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64) + else: + img = torch.cat( + [ + torchvision.utils.make_grid(x, padding=1, pad_value=64)[None] + for x in img + ], + 0, + ) + + image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + torchvision.utils.save_image( + img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 + ) + log_string(f"wrote {image_name}") + + +###################################################################### + +log_string(f"device {device}") + + +def pruner_horizontal_green(p): + return not ("green" in p and ("left" in p or "right" in p)) + + +task = TaskPicoCLVR( + batch_size=args.batch_size, + height=args.height, + width=args.width, + nb_colors=args.nb_colors, + device=device, + pruner_train=pruner_horizontal_green + if args.prune_properties in {"train+eval"} + else None, + pruner_eval=(lambda p: not pruner_horizontal_green(p)) + if args.prune_properties in {"train+eval", "eval"} + else None, +) + +vocabulary_size = task.vocabulary_size() + +log_string(f"vocabulary_size {vocabulary_size}") + +############################## + +model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, +) + +model.to(device) + +nb_parameters = sum(p.numel() for p in model.parameters()) +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") + +###################################################################### + +nb_epochs_finished = 0 + +if args.no_checkpoint: + log_string(f"not trying to load checkpoint.") + +else: + try: + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + checkpoint = torch.load(checkpoint_name) + nb_epochs_finished = checkpoint["nb_epochs_finished"] + model.load_state_dict(checkpoint["model_state"]) + torch.set_rng_state(checkpoint["rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) + + log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.") + + except FileNotFoundError: + log_string("starting from scratch.") + + except: + log_string("error when loading the checkpoint.") + exit(1) + +###################################################################### + +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +token_count = 0 +for input in task.batches(split="train"): + token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +############################## + +if args.learning_rate_schedule == "cos": + learning_rate_schedule = {} + for n_epoch in range(args.nb_epochs): + u = n_epoch / args.nb_epochs * math.pi + learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u)) +else: + u = { + int(k): float(v) + for k, v in [ + tuple(x.split(":")) for x in args.learning_rate_schedule.split(",") + ] + } + + learning_rate_schedule = {} + learning_rate = args.learning_rate + for n_epoch in range(args.nb_epochs): + if n_epoch in u: + learning_rate = u[n_epoch] + learning_rate_schedule[n_epoch] = learning_rate + +log_string(f"learning_rate_schedule {learning_rate_schedule}") + +############################## + +nb_samples_seen = 0 + +if nb_epochs_finished >= nb_epochs: + task.produce_results(nb_epochs_finished, model) + +for n_epoch in range(nb_epochs_finished, nb_epochs): + + learning_rate = learning_rate_schedule[n_epoch] + + log_string(f"learning_rate {learning_rate}") + + if args.optim == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + elif args.optim == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + elif args.optim == "adamw": + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + else: + raise ValueError(f"Unknown optimizer {args.optim}.") + + model.train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in task.batches(split="train"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + nb_samples_seen += input.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.autograd.no_grad(): + + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + for input in task.batches(split="test"): + input = input.to(device) + + # input, loss_masks, true_images = task.excise_last_image(input) + # input, loss_masks = task.add_true_image(input, true_images, loss_masks) + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) + + task.produce_results(n_epoch, model) + + checkpoint = { + "nb_epochs_finished": n_epoch + 1, + "model_state": model.state_dict(), + "rng_state": torch.get_rng_state(), + } + + if torch.cuda.is_available(): + checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state() + + checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name) + torch.save(checkpoint, checkpoint_name) + log_string(f"saved checkpoint {checkpoint_name}") + +###################################################################### diff --git a/mygpt.py b/mygpt.py new file mode 100755 index 0000000..0ed7eb0 --- /dev/null +++ b/mygpt.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math + +import torch + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + bs.x = bs.x + self.f(bs).x + return bs + + +###################################################################### + +# A BracketedSequence is a BxTx... tensor with a first and a nb time +# steps to compute. + +# Modules able to process it expect that they will have to process a +# first bracket starting at t=0, followed by a succession of brackets +# that move forward in time, do not overlap, and cover the axis T with +# no holes. +# +# Although it is more general, for a classical prompt-conditioned +# auto-regressive process it will be a first bracket starting at 0 and +# of arbitrary length for the "prompt", followed by brackets of length +# 1 for the successive tokens. +# +# Modules able to process brackets may implement a cache that is +# resetted when the input bracket starts at t=0 + + +class BracketedSequence: + def __init__(self, x, first=None, nb=None): + self.x = x + self.first = 0 if first is None else first + self.nb = x.size(1) if nb is None else nb + + def slice(self): + return self.x[:, self.first : self.first + self.nb] + + +###################################################################### + + +class CacheWrapper(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + if bs.first == 0: + y = self.f(bs.slice()) + self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:])) + self.cache_y[:, bs.first : bs.first + bs.nb] = y + else: + self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) + + bs.x = self.cache_y + + return bs + + +############################## + + +class AddPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) + + def forward(self, bs): + if bs.first == 0: + t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ + :, None + ] + j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ + None, : + ] + k = j % 2 + self.pe = torch.sin( + t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + self.cache_y = bs.x.new(bs.x.size()) + + self.cache_y[:, bs.first : bs.first + bs.nb] = ( + bs.slice() + self.pe[bs.first : bs.first + bs.nb] + ) + + bs.x = self.cache_y + + return bs + + +############################## + + +class QKVAttention(nn.Module): + def __init__( + self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.causal = causal + self.attention_dropout = attention_dropout + + self.w_q = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def forward(self, bs_q, x_kv=None): + x_q = bs_q.x + if x_kv is None: + x_kv = x_q + + if bs_q.first == 0: + self.cache_k = x_q.new_zeros( + x_q.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1) + ) + self.cache_v = x_q.new_zeros( + x_q.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1) + ) + self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) + + q = torch.einsum( + "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q + ) + self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum( + "ntc,hdc->nhtd", x_kv[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k + ) + self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum( + "ntc,hdc->nhtd", x_kv[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v + ) + + a = torch.einsum( + "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] + ) / math.sqrt(self.w_q.size(1)) + + if self.causal: + if bs_q.first == 0: + self.cache_attzero = ( + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + < torch.arange(x_kv.size(1), device=q.device)[None, None, None, :] + ) + a = a.masked_fill( + self.cache_attzero[ + :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb + ], + float("-inf"), + ) + + a = a.softmax(dim=3) + a = F.dropout(a, self.attention_dropout, self.training) + + y = torch.einsum( + "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb] + ).flatten(2) + + self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o + + bs_q.x = self.cache_y + + return bs_q + + +############################## + + +class MyGPT(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + causal=False, + dropout=0.0, + len_max=1e5, + ): + + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), + AddPositionalEncoding(len_max), + ) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + attention_dropout=dropout, + ), + ), + WithResidual( + CacheWrapper( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = CacheWrapper( + nn.Linear(in_features=dim_model, out_features=vocabulary_size) + ) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, bs): + bs.x = F.pad(bs.x, (1, -1)) + bs = self.embedding(bs) + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + + +###################################################################### + +if __name__ == "__main__": + + print("Basic check.") + + vocabulary_size = 10 + x = torch.randint(vocabulary_size, (9, 7)) + + model = MyGPT( + vocabulary_size=vocabulary_size, + dim_model=18, + dim_keys=50, + dim_hidden=100, + nb_heads=2, + nb_blocks=1, + dropout=0.1, + ) + + model.eval() + + y1 = model(BracketedSequence(x)).x + + y2 = torch.randn_like(y1) + for s in range(x.size(1)): + z = model(BracketedSequence(x, s, 1)) + y2[:, s] = z.x[:, s] + + # print(y1.max(dim = 2).values) + # print(y2.max(dim = 2).values) + print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") + +###################################################################### diff --git a/picoclvr.py b/picoclvr.py new file mode 100755 index 0000000..94c0f88 --- /dev/null +++ b/picoclvr.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision +import torch.nn.functional as F + +colors = [ + [255, 255, 255], + [255, 0, 0], + [0, 128, 0], + [0, 0, 255], + [255, 255, 0], + [0, 0, 0], + [128, 0, 0], + [139, 0, 0], + [165, 42, 42], + [178, 34, 34], + [220, 20, 60], + [255, 99, 71], + [255, 127, 80], + [205, 92, 92], + [240, 128, 128], + [233, 150, 122], + [250, 128, 114], + [255, 160, 122], + [255, 69, 0], + [255, 140, 0], + [255, 165, 0], + [255, 215, 0], + [184, 134, 11], + [218, 165, 32], + [238, 232, 170], + [189, 183, 107], + [240, 230, 140], + [128, 128, 0], + [154, 205, 50], + [85, 107, 47], + [107, 142, 35], + [124, 252, 0], + [127, 255, 0], + [173, 255, 47], + [0, 100, 0], + [34, 139, 34], + [0, 255, 0], + [50, 205, 50], + [144, 238, 144], + [152, 251, 152], + [143, 188, 143], + [0, 250, 154], + [0, 255, 127], + [46, 139, 87], + [102, 205, 170], + [60, 179, 113], + [32, 178, 170], + [47, 79, 79], + [0, 128, 128], + [0, 139, 139], + [0, 255, 255], + [0, 255, 255], + [224, 255, 255], + [0, 206, 209], + [64, 224, 208], + [72, 209, 204], + [175, 238, 238], + [127, 255, 212], + [176, 224, 230], + [95, 158, 160], + [70, 130, 180], + [100, 149, 237], + [0, 191, 255], + [30, 144, 255], + [173, 216, 230], + [135, 206, 235], + [135, 206, 250], + [25, 25, 112], + [0, 0, 128], + [0, 0, 139], + [0, 0, 205], + [65, 105, 225], + [138, 43, 226], + [75, 0, 130], + [72, 61, 139], + [106, 90, 205], + [123, 104, 238], + [147, 112, 219], + [139, 0, 139], + [148, 0, 211], + [153, 50, 204], + [186, 85, 211], + [128, 0, 128], + [216, 191, 216], + [221, 160, 221], + [238, 130, 238], + [255, 0, 255], + [218, 112, 214], + [199, 21, 133], + [219, 112, 147], + [255, 20, 147], + [255, 105, 180], + [255, 182, 193], + [255, 192, 203], + [250, 235, 215], + [245, 245, 220], + [255, 228, 196], + [255, 235, 205], + [245, 222, 179], + [255, 248, 220], + [255, 250, 205], + [250, 250, 210], + [255, 255, 224], + [139, 69, 19], + [160, 82, 45], + [210, 105, 30], + [205, 133, 63], + [244, 164, 96], + [222, 184, 135], + [210, 180, 140], + [188, 143, 143], + [255, 228, 181], + [255, 222, 173], + [255, 218, 185], + [255, 228, 225], + [255, 240, 245], + [250, 240, 230], + [253, 245, 230], + [255, 239, 213], + [255, 245, 238], + [245, 255, 250], + [112, 128, 144], + [119, 136, 153], + [176, 196, 222], + [230, 230, 250], + [255, 250, 240], + [240, 248, 255], + [248, 248, 255], + [240, 255, 240], + [255, 255, 240], + [240, 255, 255], + [255, 250, 250], + [192, 192, 192], + [220, 220, 220], + [245, 245, 245], +] + +color_names = [ + "white", + "red", + "green", + "blue", + "yellow", + "black", + "maroon", + "dark_red", + "brown", + "firebrick", + "crimson", + "tomato", + "coral", + "indian_red", + "light_coral", + "dark_salmon", + "salmon", + "light_salmon", + "orange_red", + "dark_orange", + "orange", + "gold", + "dark_golden_rod", + "golden_rod", + "pale_golden_rod", + "dark_khaki", + "khaki", + "olive", + "yellow_green", + "dark_olive_green", + "olive_drab", + "lawn_green", + "chartreuse", + "green_yellow", + "dark_green", + "forest_green", + "lime", + "lime_green", + "light_green", + "pale_green", + "dark_sea_green", + "medium_spring_green", + "spring_green", + "sea_green", + "medium_aqua_marine", + "medium_sea_green", + "light_sea_green", + "dark_slate_gray", + "teal", + "dark_cyan", + "aqua", + "cyan", + "light_cyan", + "dark_turquoise", + "turquoise", + "medium_turquoise", + "pale_turquoise", + "aqua_marine", + "powder_blue", + "cadet_blue", + "steel_blue", + "corn_flower_blue", + "deep_sky_blue", + "dodger_blue", + "light_blue", + "sky_blue", + "light_sky_blue", + "midnight_blue", + "navy", + "dark_blue", + "medium_blue", + "royal_blue", + "blue_violet", + "indigo", + "dark_slate_blue", + "slate_blue", + "medium_slate_blue", + "medium_purple", + "dark_magenta", + "dark_violet", + "dark_orchid", + "medium_orchid", + "purple", + "thistle", + "plum", + "violet", + "magenta", + "orchid", + "medium_violet_red", + "pale_violet_red", + "deep_pink", + "hot_pink", + "light_pink", + "pink", + "antique_white", + "beige", + "bisque", + "blanched_almond", + "wheat", + "corn_silk", + "lemon_chiffon", + "light_golden_rod_yellow", + "light_yellow", + "saddle_brown", + "sienna", + "chocolate", + "peru", + "sandy_brown", + "burly_wood", + "tan", + "rosy_brown", + "moccasin", + "navajo_white", + "peach_puff", + "misty_rose", + "lavender_blush", + "linen", + "old_lace", + "papaya_whip", + "sea_shell", + "mint_cream", + "slate_gray", + "light_slate_gray", + "light_steel_blue", + "lavender", + "floral_white", + "alice_blue", + "ghost_white", + "honeydew", + "ivory", + "azure", + "snow", + "silver", + "gainsboro", + "white_smoke", +] + +color_id = dict([(n, k) for k, n in enumerate(color_names)]) +color_tokens = dict([(n, c) for n, c in zip(color_names, colors)]) + +###################################################################### + + +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [] + + for r, c_r in [(k, color_names[square_c[k]]) for k in range(nb_squares)]: + s += [f"there is {c_r}"] + + if square_i[r] >= height - height // 3: + s += [f"{c_r} bottom"] + if square_i[r] < height // 3: + s += [f"{c_r} top"] + if square_j[r] >= width - width // 3: + s += [f"{c_r} right"] + if square_j[r] < width // 3: + s += [f"{c_r} left"] + + for t, c_t in [(k, color_names[square_c[k]]) for k in range(nb_squares)]: + if square_i[r] > square_i[t]: + s += [f"{c_r} below {c_t}"] + if square_i[r] < square_i[t]: + s += [f"{c_r} above {c_t}"] + if square_j[r] > square_j[t]: + s += [f"{c_r} right of {c_t}"] + if square_j[r] < square_j[t]: + s += [f"{c_r} left of {c_t}"] + + return s + + +###################################################################### + +# Generates sequences + + +def generate( + nb, + height, + width, + max_nb_squares=5, + max_nb_properties=10, + nb_colors=5, + pruner=None, +): + + assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 + + descr = [] + + for n in range(nb): + + nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + square_position = torch.randperm(height * width)[:nb_squares] + + # color 0 is white and reserved for the background + square_c = torch.randperm(nb_colors)[:nb_squares] + 1 + square_i = square_position.div(width, rounding_mode="floor") + square_j = square_position % width + + img = [0] * height * width + for k in range(nb_squares): + img[square_position[k]] = square_c[k] + + # generates all the true properties + + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) + + if pruner is not None: + s = list(filter(pruner, s)) + + # picks at most max_nb_properties at random + + nb_properties = torch.randint(max_nb_properties, (1,)) + 1 + s = ( + " ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]]) + + " " + + " ".join([f"{color_names[n]}" for n in img]) + ) + + descr += [s] + + return descr + + +###################################################################### + +# Extracts the image after in descr as a 1x3xHxW tensor + + +def descr2img(descr, n, height, width): + + if type(descr) == list: + return torch.cat([descr2img(d, n, height, width) for d in descr], 0) + + if type(n) == list: + return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze( + 0 + ) + + def token2color(t): + try: + return color_tokens[t] + except KeyError: + return [128, 128, 128] + + d = descr.split("") + d = d[n + 1] if len(d) > n + 1 else "" + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] + img = torch.tensor(d).permute(1, 0) + img = img.reshape(1, 3, height, width) + + return img + + +###################################################################### + +# Returns all the properties of the image after in descr + + +def descr2properties(descr, height, width): + + if type(descr) == list: + return [descr2properties(d, height, width) for d in descr] + + d = descr.split("") + d = d[-1] if len(d) > 1 else "" + d = d.strip().split(" ")[: height * width] + if len(d) != height * width: + return [] + + seen = {} + for k, x in enumerate(d): + if x != color_names[0]: + if x in color_tokens: + if x in seen: + return [] + else: + return [] + seen[x] = (color_id[x], k // width, k % width) + + square_infos = tuple(zip(*seen.values())) + + if square_infos: + square_c = torch.tensor(square_infos[0]) + square_i = torch.tensor(square_infos[1]) + square_j = torch.tensor(square_infos[2]) + else: + square_c = torch.tensor([]) + square_i = torch.tensor([]) + square_j = torch.tensor([]) + + s = all_properties(height, width, len(seen), square_i, square_j, square_c) + + return s + + +###################################################################### + +# Returns a triplet composed of (1) the total number of properties +# before in descr, (2) the total number of properties the image +# after verifies, and (3) the number of properties in (1) not in +# (2) + + +def nb_properties(descr, height, width, pruner=None): + + if type(descr) == list: + return [nb_properties(d, height, width, pruner) for d in descr] + + d = descr.split("", 1) + if len(d) == 0: + return 0 + d = d[0].strip().split("") + d = [x.strip() for x in d] + + all_properties = set(descr2properties(descr, height, width)) + + if pruner is None: + requested_properties = set(d) + else: + requested_properties = set(filter(pruner, d)) + + missing_properties = requested_properties - all_properties + + return (len(requested_properties), len(all_properties), len(missing_properties)) + + +###################################################################### + +if __name__ == "__main__": + for n in range(16): + descr = generate(nb=1, height=12, width=16) + + print(nb_properties(descr, height=12, width=16)) + + with open(f"picoclvr_example_{n:02d}.txt", "w") as f: + for d in descr: + f.write(f"{d}\n\n") + + img = descr2img(descr, n=0, height=12, width=16) + if img.size(0) == 1: + img = F.pad(img, (1, 1, 1, 1), value=64) + + torchvision.utils.save_image( + img / 255.0, + f"picoclvr_example_{n:02d}.png", + padding=1, + nrow=4, + pad_value=0.8, + ) + + import time + + start_time = time.perf_counter() + descr = generate(nb=1000, height=12, width=16) + end_time = time.perf_counter() + print(f"{len(descr) / (end_time - start_time):.02f} samples per second") + +###################################################################### diff --git a/tensorstack.py b/tensorstack.py new file mode 100755 index 0000000..3218be1 --- /dev/null +++ b/tensorstack.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +from torch import Tensor + +import sys + + +def exception_hook(exc_type, exc_value, tb): + r"""Hacks the call stack message to show all the local variables in + case of RuntimeError or ValueError, and prints tensors as shape, + dtype and device. + + """ + + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" + + while tb: + print("--------------------------------------------------\n") + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + print(f' File "{filename}", line {line_no}, in {name}') + print(open(filename, "r").readlines()[line_no - 1]) + + if exc_type in {RuntimeError, ValueError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") + + print() + tb = tb.tb_next + + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") + + +sys.excepthook = exception_hook + +###################################################################### + +if __name__ == "__main__": + + import torch + + def dummy(a, b): + print(a @ b) + + def blah(a, b): + c = b + b + dummy(a, c) + + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm) -- 2.20.1