parser.add_argument("--caterpillar_height", type=int, default=None)
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=True)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=True)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
parser.add_argument("--nb_blocks", type=int, default=None)
dropout=args.dropout,
attention_layer=args.attention,
logger=log_string,
- **sup_args,
+ args=args,
)
model.to(device)
nb_train_samples += input.size(0)
nb_samples_seen += input.size(0)
- total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
it += 1
lr = get_lr(n_epoch, it)