x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
 
-    return result
+    return x_t_minus_1
 
 
 ######################################################################
 
         hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        hat_x_t_minus_1 = one_iteration_prediction * x_0 + (
+        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
             1 - one_iteration_prediction
         ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
 
 
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
         nb_disagreements = 0
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
     for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise