parser.add_argument("--nb_test_samples", type=int, default=1000)
 
+parser.add_argument("--nb_train_alien_samples", type=int, default=0)
+
+parser.add_argument("--nb_test_alien_samples", type=int, default=0)
+
 parser.add_argument("--nb_c_quizzes", type=int, default=2500)
 
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
     logger=log_string,
     device=main_device,
 )
-
 # ------------------------------------------------------
 
 ######################################################################
                         subparam._grad.data = subparam._grad.data.to(device)
 
 
-######################################################################
-
-from mygpt import (
-    WithResidual,
-    CacheWrapper,
-    VaswaniPositionalEncoding,
-    TrainablePositionalEncoding,
-    QKVAttention,
-    BracketedSequence,
-)
-
-
-class Thinker(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        f_len,
-        dropout=0.0,
-        len_max=1e5,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            VaswaniPositionalEncoding(len_max),
-        )
-
-        def trunk(depth):
-            trunk_blocks = []
-
-            for b in range(nb_blocks):
-                trunk_blocks += [
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                        ),
-                        QKVAttention(
-                            dim_in=dim_model,
-                            dim_qk=dim_keys,
-                            dim_v=dim_model // nb_heads,
-                            nb_heads=nb_heads,
-                            attention_dropout=dropout,
-                        ),
-                    ),
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                            nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                            nn.ReLU(),
-                            nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                            nn.Dropout(dropout),
-                        ),
-                    ),
-                ]
-
-            return nn.Sequential(*trunk_blocks)
-
-        self.bottom_trunk = trunk(nb_blocks // 2)
-
-        self.top_trunk = trunk(nb_blocks // 2)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        for m in self.modules():
-            m.loss = 0
-
-        L = bs.x.size(1) // 3
-
-        bs = self.embedding(bs)
-        A_fA = BracketedSequence(bs.x[:, : 2 * L])
-        B = BracketedSequence(bs.x[:, -L:])
-
-        bs = BracketedSequence(
-            torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
-        )
-        bs = self.bottom_trunk(bs)
-        bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
-        bs = self.top_trunk(bs)
-        bs = BracketedSequence(bs.x[:, f_len:, :])
-        bs = self.readout(bs)
-
-        for m in self.modules():
-            if m is not self:
-                self.loss += m.loss
-
-        return bs
-
-
 ######################################################################
 
 
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    VaswaniPositionalEncoding,
+    CachedVaswaniPositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
         )
 
         # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
 
         trunk_blocks = []
 
         return bs
 
 
-######################################################################
-
-# f = phi(A, f(A)) + phi(B, f(B))
-# \hat{f(A)} = psi(A, f)
-# \hat{A} = psi_inv(f(A), f)
-# \hat{f(B)} = psi(B, f)
-# \hat{B} = psi_inv(f(B), f)
-
-
-def attention_layer(dim_model, dim_keys, nb_heads, dropout):
-    return WithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-        ),
-        QKVAttention(
-            dim_in=dim_model,
-            dim_qk=dim_keys,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention_dropout=dropout,
-        ),
-    )
-
-
-class FunctionalAE(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        dropout=0.0,
-        len_max=1024,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = CacheWrapper(
-            nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
-            ),
-        )
-
-        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
-
-        def trunk(nb, bottom=True):
-            trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
-
-            la = [
-                QKVAttention(
-                    dim_in=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    attention_dropout=dropout,
-                ),
-            ]
-
-            # if not bottom:
-            # trunk_blocks += la
-
-            for b in range(nb):
-                trunk_blocks += [
-                    attention_block(dim_model, dim_keys, nb_heads, dropout),
-                    ffw_block(dim_model, dim_hidden, nb_heads, dropout),
-                ]
-
-            # if bottom:
-            # trunk_blocks += la
-
-            return nn.Sequential(*trunk_blocks)
-
-        self.phi = trunk(nb_blocks // 2, bottom=True)
-        nb_f_tokens = 200
-        self.f_tokens = nn.Parameter(
-            torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
-        )
-        self.psi = trunk(nb_blocks // 2, bottom=False)
-        self.psi_inv = trunk(nb_blocks // 2, bottom=False)
-        self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        def cat(*x):
-            return BracketedSequence(torch.cat(x, dim=1))
-
-        if torch.is_tensor(bs):
-            return self.forward(BracketedSequence(bs)).x
-        bs = self.embedding(bs)
-        bs = self.positional_encoding(bs)
-
-        x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
-
-        K = self.f_tokens.size(1)
-        N, L = x_A.size()[:2]
-
-        ft = self.f_tokens.expand(N, -1, -1)
-
-        theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
-        theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
-
-        # if self.hook_theta is not None:
-        # self.hook_theta(theta_A, theta_B)
-
-        hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
-        hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
-
-        hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
-        hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
-
-        bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
-
-        bs = self.readout(bs)
-        return bs
-
-
 ######################################################################
 
 # quad_order, quad_generate, quad_noise, quad_loss
     data_structures,
     local_device,
     c_quizzes=None,
+    alien_quiz_machine=None,
+    nb_aliens=None,
     desc=None,
     batch_size=args.batch_size,
 ):
             f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
         )
 
-        model.test_accuracy = nb_correct / nb_total
-
         # Save some images
 
-        for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            result, predicted_parts, correct_parts = bag_to_tensors(record)
+        if n_epoch < 50:
+            for f, record in [("prediction", record_d), ("generation", record_nd)]:
+                result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+                filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result[:128],
-                predicted_parts=predicted_parts[:128],
-                correct_parts=correct_parts[:128],
-            )
+                quiz_machine.problem.save_quizzes_as_image(
+                    args.result_dir,
+                    filename,
+                    quizzes=result[:128],
+                    predicted_parts=predicted_parts[:128],
+                    correct_parts=correct_parts[:128],
+                )
 
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
+
+        return nb_correct / nb_total
 
 
 ######################################################################
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device)
+    model.test_accuracy = run_ae_test(
+        model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
+    )
+
+    if args.nb_test_alien_samples > 0:
+        run_ae_test(
+            model,
+            alien_quiz_machine,
+            n_epoch,
+            c_quizzes=None,
+            local_device=local_device,
+            prefix="alien",
+        )
 
 
 ######################################################################
 
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
+
+    def copy_for_inference(model):  # returns an independent copy of `model` for use in this call
+        return copy.deepcopy(model).to(local_device).eval()  # deep copy, moved to local_device, set to eval mode
+
     quad_order = ("A", "f_A", "B", "f_B")
 
     template = quiz_machine.problem.create_empty_quizzes(
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    def copy_for_inference(model):
-        return copy.deepcopy(model).to(local_device).eval()
-
     wanted_nb = nb
     nb_to_save = 256
     nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)