for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+        nv = " ".join([str(x.item()) for x in nv])
 
         log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
 
 
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-        temperature = 10
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
 
         # warnings.warn("noise injection", RuntimeWarning)
         # noise_std = torch.rand(1).item()