(("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
     (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
     (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
 
             output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
             dist = torch.distributions.categorical.Categorical(logits=output)
-            input[:, 3 * L :] = dist.sample()
+            input[:, 3 * L + 1 :] = dist.sample()[:, 1:]
+
+            problem.save_quizzes_as_image(
+                args.result_dir,
+                f"thinker_prediction_{n_epoch:04d}.png",
+                quizzes=input,
+                # predicted_parts=predicted_parts,
+                # correct_parts=correct_parts,
+            )
 
 
 ######################################################################