parser.add_argument("--nb_models", type=int, default=5)
 
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
 parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 ######################################################################
 
 
-def new_model(i):
+def new_model(id=-1):
     if args.model_type == "standard":
         model_constructor = attae.AttentionAE
     elif args.model_type == "functional":
         dropout=args.dropout,
     )
 
-    model.id = i
+    model.id = id
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     model.test_accuracy = 0.0
     model.nb_epochs = 0
 
 ######################################################################
 
+
+def inject_plasticity(model, proba):
+    if proba <= 0:
+        return
+
+    dummy = new_model()
+
+    with torch.no_grad():
+        for p, q in zip(mode.parameters(), dummy.parameters()):
+            mask = (torch.rand(p.size()) <= proba).long()
+            p[...] = (1 - mask) * p + mmask * q
+
+
+######################################################################
+
+chunk_size = 100
+
 problem = grids.Grids(
-    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-    chunk_size=100,
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // chunk_size,
+    chunk_size=chunk_size,
     nb_threads=args.nb_threads,
     tasks=args.grids_world_tasks,
 )
 
 for i in range(args.nb_models):
     model = new_model(i)
-    # model = torch.compile(model)
+    model = torch.compile(model)
 
     models.append(model)
 
         test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
 
         for model in models:
+            inject_plasticity(model, args.proba_plasticity)
             model.test_accuracy = 0
 
     if train_c_quizzes is None: