parser.add_argument("--caterpillar_height", type=int, default=None)

parser.add_argument("--gate_dropout_proba", type=float, default=0.0)


def _str_to_bool(value):
    """Parse a command-line boolean such as 'true'/'false', '1'/'0', 'yes'/'no'.

    argparse's `type=bool` is a trap: bool("False") is True, so any
    non-empty value would enable the flag.  Raising ValueError makes
    argparse report an invalid-value error to the user.
    """
    v = value.strip().lower()
    if v in ("1", "true", "yes", "y"):
        return True
    if v in ("0", "false", "no", "n"):
        return False
    raise ValueError(f"invalid boolean value {value!r}")


parser.add_argument("--gate_dropout_sync", type=_str_to_bool, default=False)

parser.add_argument("--rho_inner_loss", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--grid_size", type=int, default=6)

parser.add_argument("--grid_nb_colors", type=int, default=6)

parser.add_argument("--grid_nb_shapes", type=int, default=6)
##############################
# picoclvr options
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
size=args.grid_size,
+ nb_shapes=args.grid_nb_shapes,
+ nb_colors=args.grid_nb_colors,
logger=log_string,
device=device_data,
)
dropout=args.dropout,
attention_layer=args.attention,
logger=log_string,
- **sup_args,
+ args=args,
)
# Move the model's parameters and buffers to the compute device.
model.to(device)
##############################
if "calibrate" in sup_args:
    # One forward pass over the training split so that any
    # mygpt.Calibrator modules accumulate activation statistics as a
    # side effect of the forward computation (the outputs themselves
    # are not needed here).
    for batch in task.batches(split="train", desc="calibrate"):
        batch = batch.to(device)
        model(mygpt.BracketedSequence(batch)).x

    # Walk every module and report the moments gathered by each
    # Calibrator attribute found on it.
    for module_name, module in model.named_modules():
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, mygpt.Calibrator):
                # {name} interpolation — the original "${n}" printed a
                # stray literal '$' before each value.
                print(f"####### {module_name} | {attr_name} ########################")
                mean, std = attr.moments()
                print("mean\n", mean, "\n")
                print("std\n", std, "\n")
                print("############################################\n\n")

    # Calibration is a one-shot diagnostic run: stop before training.
    exit(0)
##############################
nb_train_samples += input.size(0)
nb_samples_seen += input.size(0)
- total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
it += 1
lr = get_lr(n_epoch, it)