######################################################################
 
 
+class AdHocPositionalEncoding(nn.Module):
+    def __init__(self, dim_model, value, trainable=False):
+        super().__init__()
+        if trainable:
+            self.value = nn.Parameter(value.clone())
+        else:
+            self.register_buffer("value", value.clone())
+        self.fc = nn.Linear(
+            in_features=value.size(-1) + dim_model, out_features=dim_model
+        )
+
+    def forward(self, x):
+        value = self.value[None, :, :].repeat(x.size(0), 1, 1)
+        x = torch.cat([value, x], dim=2)
+        y = self.fc(x)
+        return y
+
+
+######################################################################
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
 
 
     problem.save_quizzes_as_image(
         args.result_dir,
-        f"culture_prediction_{n_epoch}_{model.id}.png",
+        f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
         quizzes=result[:128],
         predicted_parts=predicted_parts[:128],
         correct_parts=correct_parts[:128],
 
     problem.save_quizzes_as_image(
         args.result_dir,
-        f"culture_generation_{n_epoch}_{model.id}.png",
+        f"culture_generation_{n_epoch:04d}_{model.id:02d}.png",
         quizzes=result[:128],
     )
 
         len_max=1e4,
     )
 
-    model.positional_encoding = attae.BlockRandomPositionalEncoding(
-        args.dim_model, 100, 4
-    )
+    # model.positional_encoding = attae.BlockRandomPositionalEncoding(
+    # args.dim_model, 100, 4
+    # )
+
+    i = torch.arange(400)[:, None]
+    k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200]
+    k = torch.tensor(k)[None, :]
+    pe = (i // k) % 2
+
+    model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe)
 
     model.trunk = attae.Reasoning(
-        nb_f_tokens=25,
+        nb_f_tokens=8,
         nb_chunks=2,
         dim_model=args.dim_model,
         dim_qk=args.dim_keys,