Update.
[picoclvr.git] / tasks.py
index 6a7e639..9901715 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
 
 import torch, torchvision
 
@@ -63,7 +63,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         pass
 
     def vocabulary_size(self):
@@ -489,7 +489,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -633,9 +633,64 @@ class PicoCLVR(Task):
 ######################################################################
 
 
+def generate_2d_fourier_basis(T):
+    # Create 1D vectors for time/space in both dimensions
+    t = torch.linspace(0, T - 1, T)
+
+    # Initialize an empty list to hold the basis vectors
+    basis = [torch.ones(T, T)]  # The constant (DC) component
+
+    # Generate cosine and sine terms for both dimensions
+    for nx in range(1, T // 2 + 1):
+        for ny in range(1, T // 2 + 1):
+            # Cosine and sine components in x- and y-directions
+            cos_x = torch.cos(2 * math.pi * nx * t / T).unsqueeze(1)
+            sin_x = torch.sin(2 * math.pi * nx * t / T).unsqueeze(1)
+            cos_y = torch.cos(2 * math.pi * ny * t / T).unsqueeze(0)
+            sin_y = torch.sin(2 * math.pi * ny * t / T).unsqueeze(0)
+
+            # Basis functions in 2D as outer products
+            basis.append(torch.mm(cos_x, cos_y))  # cos(nx)cos(ny)
+            basis.append(torch.mm(sin_x, sin_y))  # sin(nx)sin(ny)
+            basis.append(torch.mm(cos_x, sin_y))  # cos(nx)sin(ny)
+            basis.append(torch.mm(sin_x, cos_y))  # sin(nx)cos(ny)
+
+    # Stack the basis into a 3D tensor (number_of_basis_vectors x T x T)
+    basis_matrix = torch.stack(basis[: T * T], dim=0)
+
+    return basis_matrix
+
+
 class MNIST(Task):
+    def create_fourier_basis(self):
+        self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
+        self.fourier_basis_inverse = self.fourier_basis.inverse()
+        y = self.train_input.float() @ self.fourier_basis.t()
+        self.fourier_range = 4
+        self.fourier_mu = y.mean(dim=0, keepdim=True)
+        self.fourier_std = y.std(dim=0, keepdim=True)
+
+    def fourier_encode(self, x):
+        y = x.float() @ self.fourier_basis.t()
+        y = ((y - self.fourier_mu) / self.fourier_std).clamp(
+            min=-self.fourier_range, max=self.fourier_range
+        )
+        y = (((y + self.fourier_range) / (2 * self.fourier_range)) * 255).long()
+        return y
+
+    def fourier_decode(self, y):
+        y = (
+            (y / 255.0) * (2 * self.fourier_range) - self.fourier_range
+        ) * self.fourier_std.to(y.device) + self.fourier_mu.to(y.device)
+        return y.float() @ self.fourier_basis_inverse.to(y.device).t()
+
     def __init__(
-        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        fourier_representation=True,
+        device=torch.device("cpu"),
     ):
         super().__init__()
 
@@ -648,6 +703,14 @@ class MNIST(Task):
         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
 
+        self.fourier_representation = fourier_representation
+
+        if fourier_representation:
+            self.create_fourier_basis()
+
+            self.train_input = self.fourier_encode(self.train_input)
+            self.test_input = self.fourier_encode(self.test_input)
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -666,6 +729,26 @@ class MNIST(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
+        if n_epoch == 0:
+            image_name = os.path.join(result_dir, "fourier.png")
+            torchvision.utils.save_image(
+                0.5
+                - 0.5
+                * self.fourier_basis.reshape(-1, 1, 28, 28)
+                / self.fourier_basis.std(),
+                image_name,
+                nrow=28,
+            )
+
+            image_name = os.path.join(result_dir, "check-train.png")
+            torchvision.utils.save_image(
+                1
+                - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+                / 256,
+                image_name,
+                nrow=16,
+            )
+
         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
         ar_mask = torch.full_like(results, 1)
         masked_inplace_autoregression(
@@ -676,6 +759,10 @@ class MNIST(Task):
             deterministic_synthesis,
             device=self.device,
         )
+
+        if self.fourier_representation:
+            results = self.fourier_decode(results)
+
         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
@@ -754,15 +841,17 @@ class Maze(Task):
     def compute_error(
         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
     ):
+        model_device = next(model.parameters()).device
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
             self.width * self.height,
             self.width * self.height,
-            device=self.device,
+            device=model_device,
             dtype=torch.int64,
         )
 
         for input in self.batches(split, nb_to_use):
+            input = input.to(model_device)
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
@@ -836,7 +925,7 @@ class Maze(Task):
                         eol = " " if j < count.size(1) - 1 else "\n"
                         f.write(f"{count[i,j]}{eol}")
 
-        input = self.test_input[:48]
+        input = self.test_input[:48].to(next(model.parameters()).device)
         result = input.clone()
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
@@ -1098,6 +1187,34 @@ class Stack(Task):
             device=self.device,
         )
 
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        for label, input in [
+            ("train", self.train_input[:32]),
+            ("test", self.test_input[:32]),
+        ]:
+            output = model(BracketedSequence(input)).x
+            output = output.log_softmax(dim=-1)
+            filename = os.path.join(
+                result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
+            )
+            with open(filename, "w") as f:
+                for n in range(input.size(0)):
+                    s = stack.seq_to_str(
+                        input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
+                    )
+                    for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
+                        u = (
+                            " " * (10 - len(w))
+                            + w
+                            + " "
+                            + str(output[n][t][k].exp().item())
+                            + "\n"
+                        )
+                        f.write(u)
+                    f.write("\n")
+            logger(f"wrote {filename}")
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
         for n in range(result.size(0)):
             logger(
                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
@@ -1685,7 +1802,7 @@ class Grid(Task):
         self.t_nul = self.token2id["#"]
         self.t_true = self.token2id["true"]
         self.t_false = self.token2id["false"]
-        self.t_pipe = self.token2id["|"]
+        self.t_pipe = self.token2id["|"]
 
         # Tokenize the train and test sets
         self.train_input = self.str2tensor(self.train_descr)
@@ -1694,7 +1811,7 @@ class Grid(Task):
             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
         )
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -1823,7 +1940,7 @@ class QMLP(Task):
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -1905,7 +2022,10 @@ class Greed(Task):
             t % self.world.it_len == self.world.index_lookahead_reward
         ).long()
 
-        return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch
+        return (
+            lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
+            + (1 - lr_mask) * batch
+        )
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1941,7 +2061,7 @@ class Greed(Task):
                 progress_bar_desc=None,
             )
             warnings.warn("keeping thinking snapshots", RuntimeWarning)
-            snapshots.append(result[:10].detach().clone())
+            snapshots.append(result[:100].detach().clone())
 
         # Generate iteration after iteration
 
@@ -1950,7 +2070,7 @@ class Greed(Task):
         result[:, self.world.it_len :] = -1
         # Set the lookahead_reward of the firs to UNKNOWN
         result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
-            2
+            greed.REWARD_UNKNOWN
         )
 
         t = torch.arange(result.size(1), device=result.device)[None, :]
@@ -1965,7 +2085,7 @@ class Greed(Task):
             if u > 0:
                 result[
                     :, u + self.world.index_lookahead_reward
-                ] = self.world.lookahead_reward2code(2)
+                ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
                 ar_mask = (t >= u + self.world.index_states).long() * (
                     t < u + self.world.index_states + self.world.state_len
                 ).long()
@@ -1974,7 +2094,7 @@ class Greed(Task):
             # Generate the action and reward with lookahead_reward to +1
             result[
                 :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(1)
+            ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
             ar_mask = (t >= u + self.world.index_reward).long() * (
                 t <= u + self.world.index_action
             ).long()
@@ -1983,11 +2103,11 @@ class Greed(Task):
             # Set the lookahead_reward to UNKNOWN for the next iterations
             result[
                 :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(2)
+            ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
 
         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
-            for n in range(10):
+            for n in range(snapshots[0].size(0)):
                 for s in snapshots:
                     lr, s, a, r = self.world.seq2episodes(
                         s[n : n + 1],