parser.add_argument("--test", type=str, default=None)
 
+parser.add_argument("--logit_std_max", type=float, default=-1)
+
 ######################################################################
 
 grids_tasks = ", ".join(
         dropout=args.dropout,
     ).to(main_device)
 
+    class UpperBoundStd(nn.Module):
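+        # normalization module that caps the std of its input along the last dim at std_max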
+        def __init__(self, std_max=1.0):
+            super().__init__()
+            self.std_max = std_max
+
+        def forward(self, x):
+            # center, then rescale only when the std exceeds std_max, so the
+            # output std is min(std, std_max); a near-zero-std input is left
+            # untouched instead of being divided by its own (tiny) std
+            y = x - x.mean(dim=-1, keepdim=True)
+            std = y.std(dim=-1, keepdim=True)
+            return y / (std / self.std_max).clamp(min=1.0)
+
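+    # when enabled, append the std cap to the readout so the final logits are bounded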
+    if args.logit_std_max > 0:
+        model.readout.f = nn.Sequential(
+            model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+        )
+
     model.id = k
     model.train_c_quiz_bags = []
     model.test_c_quiz_bags = []
 
 ######################################################################
 
+
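+# ad-hoc test mode: entropy maximization of the c-quiz generation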
 if args.test == "entropy":
     model = models[0]
     model.to(main_device)
 
-    log_string("starting testing entropy maximization")
-
-    train_input = quiz_machine.generate_c_quizzes(
-        1000,
-        model_for_generation=model,
-        procedure=c_quizzes_procedure,
-    )
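+    # optimizer used only by the gradient-based variant commented out below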
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
 
-    for n_epoch in range(10):
-        nb_train_samples, acc_train_loss = 0, 0.0
+    log_string("starting testing entropy maximization")
 
-        for input in train_input.split(args.batch_size):
-            input = input.to(main_device)
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = output.log_softmax(dim=1).mean()
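+    # each iteration: sample c-quizzes, dump them as an image, then perturb the weights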
+    for n_epoch in range(100):
+        input = quiz_machine.generate_c_quizzes(
+            128,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+        )
 
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+        filename = f"test_{n_epoch:04d}.png"
+
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=input,
+        )
 
-            model.optimizer.zero_grad()
-            loss.backward()
-            model.optimizer.step()
+        log_string(f"wrote {filename}")
 
-        log_string(
-            f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
-        )
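+        # random-walk exploration: add small Gaussian noise to all the parameters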
+        with torch.no_grad():
+            for p in model.parameters():
+                p += torch.randn_like(p) * 1e-3
+
+        # nb_train_samples, acc_train_loss = 0, 0.0
+
+        # for k in range(1000 // args.batch_size):
+        #     input = quiz_machine.generate_c_quizzes(
+        #         args.batch_size,
+        #         model_for_generation=model,
+        #         procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
+        #     )
+
+        #     input = input.to(main_device)
+        #     targets = input
+        #     output = model(mygpt.BracketedSequence(input)).x
+        #     loss = -F.cross_entropy(output.transpose(1, 2), targets)
+        #     acc_train_loss += loss.item() * input.size(0)
+        #     nb_train_samples += input.size(0)
+
+        #     optimizer.zero_grad()
+        #     loss.backward()
+        #     optimizer.step()
+
+        # log_string(
+        #     f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+        # )
 
     exit(0)