Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 79841f3..3aa696b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -99,7 +99,11 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -747,7 +751,7 @@ model = mygpt.MyGPT(
     dropout=args.dropout,
     attention_layer=args.attention,
     logger=log_string,
-    **sup_args,
+    args=args,
 )
 
 model.to(device)
@@ -905,7 +909,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         nb_train_samples += input.size(0)
         nb_samples_seen += input.size(0)
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+        total_loss = loss + (
+            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+        )
 
         it += 1
         lr = get_lr(n_epoch, it)