Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index c51035c..d6845e8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,14 +16,6 @@ import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
-
-######################################################################
-
 
 def str2bool(x):
     x = x.lower()
@@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=50)
@@ -93,6 +87,10 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=1)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -105,7 +103,13 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -139,6 +143,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -208,15 +216,25 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default
 
 ######################################################################
 
-args = parser.parse_args()
+args = parser.parse_args()
 
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
 
 if args.result_dir is None:
     args.result_dir = f"results_{args.task}_{args.model}"
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
@@ -430,6 +448,8 @@ except FileExistsError:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
 if args.seed >= 0:
@@ -466,6 +486,9 @@ log_string(f"argv {' '.join(sys.argv)}")
 for n in vars(args):
     log_string(f"args.{n} {getattr(args, n)}")
 
+for k, v in sup_args.items():
+    log_string(f'sup_args["{k}"] "{v}"')
+
 
 ######################################################################
 
@@ -503,6 +526,9 @@ def get_lr(n_epoch, it):
 ######################################################################
 
 
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -689,6 +715,8 @@ elif args.task == "grid":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         size=args.grid_size,
+        nb_shapes=args.grid_nb_shapes,
+        nb_colors=args.grid_nb_colors,
         logger=log_string,
         device=device_data,
     )
@@ -712,6 +740,9 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.memex_proba > 0:
+    vocabulary_size += 1
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -728,6 +759,8 @@ model = mygpt.MyGPT(
     causal=True,
     dropout=args.dropout,
     attention_layer=args.attention,
+    logger=log_string,
+    args=args,
 )
 
 model.to(device)
@@ -821,6 +854,25 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
+if "calibrate" in sup_args:
+    for input in task.batches(split="train", desc="calibrate"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
+
+    for n, m in model.named_modules():
+        for a in dir(m):
+            x = getattr(m, a)
+            if isinstance(x, mygpt.Calibrator):
+                print(f"####### ${n} | ${a} ########################")
+                mean, std = x.moments()
+                print("mean\n", mean, "\n")
+                print("std\n", std, "\n")
+                print(f"############################################\n\n")
+
+    exit(0)
+
+##############################
+
 nb_samples_seen = 0
 
 if nb_epochs_finished >= nb_epochs:
@@ -832,10 +884,12 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
+n_batch = 0
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -850,7 +904,29 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
+    def add_memex(batches, memex_proba):
+        for input in batches:
+            if torch.rand(1).item() < memex_proba:
+                sep = torch.full(
+                    (input.size(0), 1), vocabulary_size - 1, device=input.device
+                )
+
+                yield torch.cat(
+                    [
+                        input,
+                        sep,
+                        input,
+                    ],
+                    dim=1,
+                )
+            yield input
+
+    train_batches = add_memex(
+        task.batches(split="train"),
+        args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+    )
+
+    for input in train_batches:
         model.reset_inner_loss()
         input = input.to(device)
 
@@ -864,7 +940,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         nb_train_samples += input.size(0)
         nb_samples_seen += input.size(0)
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+        total_loss = loss + (
+            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+        )
 
         it += 1
         lr = get_lr(n_epoch, it)
@@ -877,6 +955,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss.backward()
         optimizer.step()
 
+        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+
+        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+        n_batch += 1
+
     with torch.autograd.no_grad():
         model.eval()
 
@@ -910,10 +994,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
         time_pred_result = time_current_result
 
     checkpoint = {