X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=c0ad5ffb08683de9027cc71e8335636f7096af64;hb=HEAD;hp=324376df60319e9549ae431c5d43dd04f1a29ed9;hpb=232299b8af7e66a02e64bb2e47b525e2f50b099d;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 324376d..9901715 100755 --- a/tasks.py +++ b/tasks.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -63,7 +63,7 @@ def masked_inplace_autoregression( class Task: - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): pass def vocabulary_size(self): @@ -489,7 +489,7 @@ class PicoCLVR(Task): self.train_input = self.tensorize(self.train_descr) self.test_input = self.tensorize(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm( @@ -633,9 +633,64 @@ class PicoCLVR(Task): ###################################################################### +def generate_2d_fourier_basis(T): + # Create 1D vectors for time/space in both dimensions + t = torch.linspace(0, T - 1, T) + + # Initialize an empty list to hold the basis vectors + basis = [torch.ones(T, T)] # The constant (DC) component + + # Generate cosine and sine terms for both dimensions + for nx in range(1, T // 2 + 1): + for ny in range(1, T // 2 + 1): + # Cosine and sine components in x- and y-directions + cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1) + sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1) + cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0) + sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0) + + # Basis functions in 2D as outer products + basis.append(torch.mm(cos_x, cos_y)) # cos(nx)cos(ny) + basis.append(torch.mm(sin_x, sin_y)) # sin(nx)sin(ny) + basis.append(torch.mm(cos_x, sin_y)) # cos(nx)sin(ny) + basis.append(torch.mm(sin_x, cos_y)) # sin(nx)cos(ny) + + # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T) + basis_matrix = torch.stack(basis[: T * T], dim=0) + + return basis_matrix + + class MNIST(Task): + def create_fourier_basis(self): + self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1) + self.fourier_basis_inverse = self.fourier_basis.inverse() + y = self.train_input.float() @ self.fourier_basis.t() + self.fourier_range = 4 + self.fourier_mu = y.mean(dim=0, keepdim=True) + self.fourier_std = y.std(dim=0, keepdim=True) + + def fourier_encode(self, x): + y = x.float() @ self.fourier_basis.t() + y = ((y - self.fourier_mu) / self.fourier_std).clamp( + min=-self.fourier_range, max=self.fourier_range + ) + y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long() + return y + + def fourier_decode(self, y): + y = ( + (y / 255.0) * (2 * self.fourier_range) - self.fourier_range + ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device) + return y.float() @ self.fourier_basis_inverse.to(y.device).t() + def __init__( - self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") + self, + nb_train_samples, + nb_test_samples, + batch_size, + fourier_representation=True, + device=torch.device("cpu"), ): super().__init__() @@ -648,6 +703,14 @@ class MNIST(Task): data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True) self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long() + self.fourier_representation = fourier_representation + + if fourier_representation: + self.create_fourier_basis() + + self.train_input = self.fourier_encode(self.train_input) + self.test_input = self.fourier_encode(self.test_input) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -666,6 +729,26 @@ class MNIST(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): + if n_epoch == 0: + image_name = os.path.join(result_dir, "fourier.png") + torchvision.utils.save_image( + 0.5 + - 0.5 + * self.fourier_basis.reshape(-1, 1, 28, 28) + / self.fourier_basis.std(), + image_name, + nrow=28, + ) + + image_name = os.path.join(result_dir, "check-train.png") + torchvision.utils.save_image( + 1 + - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28) + / 256, + image_name, + nrow=16, + ) + results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64) ar_mask = torch.full_like(results, 1) masked_inplace_autoregression( @@ -676,6 +759,10 @@ class MNIST(Task): deterministic_synthesis, device=self.device, ) + + if self.fourier_representation: + results = self.fourier_decode(results) + image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, @@ -754,15 +841,17 @@ class Maze(Task): def compute_error( self, model, split="train", nb_to_use=-1, deterministic_synthesis=False ): + model_device = next(model.parameters()).device nb_total, nb_correct = 0, 0 count = torch.zeros( self.width * self.height, self.width * self.height, - device=self.device, + device=model_device, dtype=torch.int64, ) for input in self.batches(split, nb_to_use): + input = input.to(model_device) result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 @@ -836,7 +925,7 @@ class Maze(Task): eol = " " if j < count.size(1) - 1 else "\n" f.write(f"{count[i,j]}{eol}") - input = self.test_input[:48] + input = self.test_input[:48].to(next(model.parameters()).device) result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 @@ -1098,6 +1187,34 @@ class Stack(Task): device=self.device, ) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + for label, input in [ + ("train", self.train_input[:32]), + ("test", self.test_input[:32]), + ]: + output = model(BracketedSequence(input)).x + output = output.log_softmax(dim=-1) + filename = os.path.join( + result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt" + ) + with open(filename, "w") as f: + for n in range(input.size(0)): + s = stack.seq_to_str( + input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits + ) + for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")): + u = ( + " " * (10 - len(w)) + + w + + " " + + str(output[n][t][k].exp().item()) + + "\n" + ) + f.write(u) + f.write("\n") + logger(f"wrote {filename}") + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + for n in range(result.size(0)): logger( f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" @@ -1685,7 +1802,7 @@ class Grid(Task): self.t_nul = self.token2id["#"] self.t_true = self.token2id["true"] self.t_false = self.token2id["false"] - self.t_pipe = self.token2id["|"] + # self.t_pipe = self.token2id["|"] # Tokenize the train and test sets self.train_input = self.str2tensor(self.train_descr) @@ -1694,7 +1811,7 @@ class Grid(Task): None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr) ) - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm( @@ -1823,7 +1940,7 @@ class QMLP(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm( @@ -1944,7 +2061,7 @@ class Greed(Task): progress_bar_desc=None, ) warnings.warn("keeping thinking snapshots", RuntimeWarning) - snapshots.append(result[:10].detach().clone()) + snapshots.append(result[:100].detach().clone()) # Generate iteration after iteration @@ -1986,11 +2103,11 @@ class Greed(Task): # Set the lookahead_reward to UNKNOWN for the next iterations result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN) + ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: - for n in range(10): + for n in range(snapshots[0].size(0)): for s in snapshots: lr, s, a, r = self.world.seq2episodes( s[n : n + 1],