##############################
 
 
+class NoiseInjector(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.noise_std = 0.0
+
+    def forward(self, x):
+        if self.noise_std > 0:
+            x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+        return x
+
+
+def set_noise_injection(model, noise_std):
+    for m in model.modules():
+        if isinstance(m, NoiseInjector):
+            m.noise_std = noise_std
+
+
+##############################
+
+
 class MyGPT(nn.Module):
     def __init__(
         self,
         for b in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(),
+                    ),
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
+                        NoiseInjector(),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
 
 from torch import nn
 from torch.nn import functional as F
 
+import mygpt
 from mygpt import BracketedSequence
 
 ######################################################################
 class Gang(nn.Module):
     def __init__(self, models, nb_models_for_generation, mode="groupthink"):
         super().__init__()
-        self.models = models
+        self.models = nn.ModuleList(models)
         self.nb_models_for_generation = nb_models_for_generation
         self.mode = mode
 
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        # bracketing of the temperature to get the target logproba if
-        # min_ave_seq_logproba is not None
+        warnings.warn("noise injection", RuntimeWarning)
+        temperature = 1
+        noise_std = torch.rand(1).item()
+        self.logger(f"{noise_std=}")
+        mygpt.set_noise_injection(model_for_generation, noise_std)
 
-        temperature = 2
-        d_temperature = 1 / 3
-
-        while True:
-            seq_logproba[...] = 0
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_prompt,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=False,
-                # progress_bar_desc="sampling c_quizzes",
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_prompt,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=False,
+            # progress_bar_desc="sampling c_quizzes",
+            device=self.device,
+        )
 
-            ave_seq_logproba = seq_logproba.mean()
+        ave_seq_logproba = seq_logproba.mean()
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                # progress_bar_desc="sampling c_quizzes",
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_solve,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            # progress_bar_desc="sampling c_quizzes",
+            device=self.device,
+        )
 
-            # If we do not have target logprobs, get out now
-            if min_ave_seq_logproba is None:
-                break
-
-            # Oh man that's ugly
-            if ave_seq_logproba < min_ave_seq_logproba:
-                if d_temperature > 0:
-                    d_temperature *= -1 / 3
-                temperature += d_temperature
-            elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
-                if d_temperature < 0:
-                    d_temperature *= -1 / 3
-                temperature += d_temperature
-            else:
-                break
-
-            self.logger(f"changing temperature to {temperature}")
+        mygpt.set_noise_injection(model_for_generation, 0.0)
 
         return c_quizzes, seq_logproba.mean()
 
 
 
     def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
+        nb_frames = (self.nb_iterations - 1) * self.speed + 1
 
         result = torch.full(
-            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            (nb * 4, nb_frames, self.height, self.width),
             self.token_empty,
         )
 
                         result[n, 0, i + vi, j + vj] = self.token_tail
                         break
 
-                if torch.rand(1) < 0.75:
-                    break
+                # if torch.rand(1) < 0.75:
+                break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations * self.speed - 1):
+        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+        valid = nb_heads > 0
+
+        for l in range(nb_frames - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
                     + (1 - mask_1_or_2_heads) * self.token_conductor
                 )
             )
+            pred_nb_heads = nb_heads
+            nb_heads = (
+                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+            )
+            valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+        result = result[valid]
 
         result = result[
             :, torch.arange(self.nb_iterations, device=result.device) * self.speed