X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=a587e967af4fb9ec05fdca261f11decc55df8ac4;hb=HEAD;hp=c51035c118f6480f2f33560b8aeea020c4cfbf28;hpb=ffe183868ac8563fd82fc8312fda90f6f8a95833;p=mygptrnn.git diff --git a/main.py b/main.py index c51035c..a587e96 100755 --- a/main.py +++ b/main.py @@ -11,19 +11,13 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + import ffutils import mygpt, tasks, problems ###################################################################### -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") - -###################################################################### - def str2bool(x): x = x.lower() @@ -55,11 +49,15 @@ parser.add_argument("--seed", type=int, default=0) parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--force_cpu", type=str2bool, default=False) + ######################################## -parser.add_argument("--nb_epochs", type=int, default=50) +parser.add_argument("--nb_epochs", type=int, default=25) + +parser.add_argument("--physical_batch_size", type=int, default=None) -parser.add_argument("--batch_size", type=int, default=None) +parser.add_argument("--batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=None) @@ -93,6 +91,10 @@ parser.add_argument("--model", type=str, default=None) parser.add_argument("--attention", type=str, default=None) +parser.add_argument("--memex_proba", type=float, default=0) + +parser.add_argument("--memex_nb_epochs", type=float, default=None) + parser.add_argument("--dim_model", type=int, default=None) parser.add_argument("--dim_keys", type=int, default=None) @@ -105,7 +107,13 @@ parser.add_argument("--nb_lines", type=int, default=None) parser.add_argument("--caterpillar_height", type=int, default=None) -parser.add_argument("--rho", type=float, default=0.0) +parser.add_argument("--gate_dropout_proba", type=float, default=0.0) + +parser.add_argument("--gate_dropout_sync", type=str2bool, default=False) + +parser.add_argument("--gate_dropout_replace", type=str2bool, default=False) + +parser.add_argument("--rho_inner_loss", type=float, default=0.0) parser.add_argument("--nb_blocks", type=int, default=None) @@ -139,6 +147,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_nb_colors", type=int, default=6) + +parser.add_argument("--grid_nb_shapes", type=int, default=6) + ############################## # picoclvr options @@ -208,109 +220,119 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default ###################################################################### -args = parser.parse_args() +# args = parser.parse_args() -assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +args, sup_args = parser.parse_known_args() + +sup_args = dict([x.removeprefix("--").split("=") for x in sup_args]) if args.result_dir is None: args.result_dir = f"results_{args.task}_{args.model}" ###################################################################### +if not args.force_cpu and torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") + +###################################################################### + default_task_args = { "addition": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "byheart": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "expr": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 2500000, "nb_test_samples": 10000, }, "grid": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "qmlp": { "model": "37M", - "batch_size": 10, + "physical_batch_size": 10, "nb_train_samples": 100000, "nb_test_samples": 1000, }, "guessop": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 1000000, "nb_test_samples": 10000, }, "learnop": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "maze": { "model": "37M", - "batch_size": 5, + "physical_batch_size": 5, "nb_train_samples": 100000, "nb_test_samples": 10000, }, "picoclvr": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "rpl": { "model": "352M", - "batch_size": 5, + "physical_batch_size": 5, "nb_train_samples": 2500000, "nb_test_samples": 10000, }, "snake": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "stack": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 100000, "nb_test_samples": 1000, }, "twotargets": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "memory": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 25000, "nb_test_samples": 10000, }, "mixing": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { "model": "37M", - "batch_size": 10, + "physical_batch_size": 5, "nb_train_samples": 60000, "nb_test_samples": 10000, }, @@ -430,6 +452,9 @@ except FileExistsError: print(f"result directory {args.result_dir} already exists") exit(1) +loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a") +lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a") + log_file = open(os.path.join(args.result_dir, args.log_filename), "a") if args.seed >= 0: @@ -466,6 +491,9 @@ log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): log_string(f"args.{n} {getattr(args, n)}") +for k, v in sup_args.items(): + log_string(f'sup_args["{k}"] "{v}"') + ###################################################################### @@ -503,6 +531,133 @@ def get_lr(n_epoch, it): ###################################################################### +def add_memex_v1(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(1 + 2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + memex_mask = new_input.new_zeros(new_input.size()) + memex_mask[:, input.size(1) :] = 1.0 + + yield new_input, memex_mask + + yield input + + +# The marker token is not used for this one +def add_memex_v2(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand( + input.size(0), -1 + ) + t = t + torch.randint( + input.size(1) - t.size(1), (t.size(0), 1), device=t.device + ) + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + flash = input[n, t] + new_input = torch.cat([input, flash], dim=1) + + memex_mask = new_input.new_zeros(new_input.size()) + memex_mask[:, input.size(1) :] = 1.0 + + yield new_input, memex_mask + + else: + yield input + + +def add_memex_v3(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + memex_len = input.size(1) // 4 + + t = torch.arange(input.size(1) + memex_len, device=input.device)[ + None, : + ].expand(input.size(0), -1) + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + # Call me the tensor-spaghetti master + + trigger = torch.rand(t.size(), device=t.device) + trigger[:, -memex_len:] = 2.0 + trigger[:, 0] = 2.0 + trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long() + memex_mask = trigger.clone() + memex_mask[:, memex_len:] -= trigger[:, :-memex_len] + memex_mask = memex_mask.cumsum(dim=1) + + u = 1 - memex_mask + u[:, 0] = 0 + u = u.cumsum(dim=1) + assert u.min() == 0 + assert u.max() == input.size(1) - 1 + + v = ( + (trigger.cumsum(dim=1) - trigger).cumsum(dim=1) + + torch.randint( + input.size(1) - memex_len, (input.size(0), 1), device=t.device + ) + ) * memex_mask + assert v.min() >= 0 + assert v.max() < input.size(1) + u = u * (1 - memex_mask) + v * memex_mask + + new_input = input[n, u] + assert input.max() < vocabulary_size + assert new_input.max() < vocabulary_size + limits = trigger.clone() + limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)] + assert limits.min() == 0 + assert limits.max() == 1 + new_input = new_input * (1 - limits) + marker_token * limits + assert marker_token < vocabulary_size + assert new_input.max() < vocabulary_size + + yield new_input, memex_mask + + else: + yield input + + +###################################################################### + +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} + +assert args.batch_size % args.physical_batch_size == 0 + + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -528,7 +683,7 @@ if args.task == "byheart": problem=problems.ProblemByHeart(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -539,7 +694,7 @@ elif args.task == "learnop": problem=problems.ProblemLearnOperator(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -550,7 +705,7 @@ elif args.task == "guessop": problem=problems.ProblemGuessOperator(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -561,7 +716,7 @@ elif args.task == "twotargets": problem=problems.ProblemTwoTargets(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -571,7 +726,7 @@ elif args.task == "memory": problem=problems.ProblemMemory(len_total=args.memory_len_total), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -583,7 +738,7 @@ elif args.task == "mixing": ), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -593,7 +748,7 @@ elif args.task == "addition": problem=problems.ProblemAddition(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -602,7 +757,7 @@ elif args.task == "picoclvr": task = tasks.PicoCLVR( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.picoclvr_height, width=args.picoclvr_width, nb_colors=args.picoclvr_nb_colors, @@ -616,7 +771,7 @@ elif args.task == "mnist": task = tasks.MNIST( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, device=device_data, ) @@ -624,7 +779,7 @@ elif args.task == "maze": task = tasks.Maze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.maze_height, width=args.maze_width, nb_walls=args.maze_nb_walls, @@ -635,7 +790,7 @@ elif args.task == "snake": task = tasks.Snake( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.snake_height, width=args.snake_width, nb_colors=args.snake_nb_colors, @@ -648,7 +803,7 @@ elif args.task == "stack": task = tasks.Stack( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, nb_steps=args.stack_nb_steps, nb_stacks=args.stack_nb_stacks, @@ -665,7 +820,7 @@ elif args.task == "expr": sequence_length=args.expr_sequence_length, operand_max=args.expr_operand_max, result_max=args.expr_result_max, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, device=device_data, ) @@ -673,7 +828,7 @@ elif args.task == "rpl": task = tasks.RPL( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, nb_starting_values=args.rpl_nb_starting_values, max_input=args.rpl_max_input, prog_len=args.rpl_prog_len, @@ -687,8 +842,10 @@ elif args.task == "grid": task = tasks.Grid( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, size=args.grid_size, + nb_shapes=args.grid_nb_shapes, + nb_colors=args.grid_nb_colors, logger=log_string, device=device_data, ) @@ -697,7 +854,7 @@ elif args.task == "qmlp": task = tasks.QMLP( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, device=device_data, @@ -712,6 +869,10 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() +if args.memex_proba > 0: + memex_marker = vocabulary_size + vocabulary_size += 1 + log_string(f"vocabulary_size {vocabulary_size}") ############################## @@ -728,6 +889,8 @@ model = mygpt.MyGPT( causal=True, dropout=args.dropout, attention_layer=args.attention, + logger=log_string, + args=args, ) model.to(device) @@ -821,6 +984,25 @@ if args.max_percents_of_test_in_train >= 0: ############################## +if "calibrate" in sup_args: + for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + + for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + + exit(0) + +############################## + nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: @@ -832,10 +1014,38 @@ if nb_epochs_finished >= nb_epochs: deterministic_synthesis=args.deterministic_synthesis, ) -time_pred_result = None +time_pred_result = datetime.datetime.now() it = 0 +n_batch = 0 + + +def the_dot_products(value1, value2, params): + g1g1, g1g2, g2g2 = 0, 0, 0 + for p in params: + g1 = torch.autograd.grad(value1, p, retain_graph=True)[0] + g2 = torch.autograd.grad(value2, p, retain_graph=True)[0] + g1g1 += g1.pow(2).sum()[None] + g2g2 += g2.pow(2).sum()[None] + g1g2 += (g1 * g2).sum()[None] + return torch.cat([g1g1, g1g2, g2g2]) + + +def update_ave_grad(value, params, name, eps=1e-3): + for p in params: + g = torch.autograd.grad(value, p, retain_graph=True)[0] + ag = getattr(p, name) if hasattr(p, name) else 0 + setattr(p, name, (1 - eps) * ag + eps * g) + + +def norm(params, name): + s = 0 + for p in params: + s += getattr(p, name).pow(2).sum() + return s + + for n_epoch in range(nb_epochs_finished, nb_epochs): if args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) @@ -850,32 +1060,92 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - for input in task.batches(split="train"): - model.reset_inner_loss() - input = input.to(device) + memex_proba = ( + args.memex_proba + if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs + else 0.0 + ) - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - inner_loss = model.get_inner_loss() + log_string(f"memex_proba {memex_proba}") - acc_train_loss += loss.item() * input.size(0) - acc_train_inner_loss += inner_loss.item() * input.size(0) + warnings.warn("memex v3", RuntimeWarning) + train_batches = add_memex_v3( + batches=task.batches(split="train"), + memex_proba=memex_proba, + marker_token=memex_marker, + ) - nb_train_samples += input.size(0) - nb_samples_seen += input.size(0) + def add_none(it): + for x in it: + yield x + yield None + + nb_acc_samples = 0 + + for input in add_none(train_batches): + if input is not None: + if type(input) is tuple: + input, memex_mask = input + memex_mask = memex_mask.to(device) + else: + memex_mask = None + + model.reset_inner_loss() + input = input.to(device) + + output = model(mygpt.BracketedSequence(input)).x - total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + if memex_mask is None: + loss = F.cross_entropy(output.transpose(1, 2), input) + else: + loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none") + loss_regular = (loss * (1 - memex_mask)).mean() + loss_memex = (loss * memex_mask).mean() - it += 1 - lr = get_lr(n_epoch, it) - for param_group in optimizer.param_groups: - param_group["lr"] = lr + if it < 100 or torch.rand(1) < 0.01: + update_ave_grad(loss_regular, model.parameters(), "grad_regular") + update_ave_grad(loss_memex, model.parameters(), "grad_memex") + norm_regular = norm(model.parameters(), "grad_regular") + norm_memex = norm(model.parameters(), "grad_memex") + l_memex = ( + max(norm_regular, norm_memex) - norm_regular + ) / norm_memex - # log_string(f"learning_rate {lr}") + loss = loss_regular + l_memex * loss_memex - optimizer.zero_grad() - total_loss.backward() - optimizer.step() + inner_loss = model.get_inner_loss() + + acc_train_loss += loss.item() * input.size(0) + acc_train_inner_loss += inner_loss.item() * input.size(0) + + nb_train_samples += input.size(0) + nb_samples_seen += input.size(0) + + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) + + it += 1 + lr = get_lr(n_epoch, it) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + # log_string(f"learning_rate {lr}") + + total_loss.backward() + nb_acc_samples += input.size(0) + + if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size: + assert nb_acc_samples <= args.batch_size + optimizer.step() + grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() + loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + lambda_file.write( + f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n" + ) + optimizer.zero_grad() + nb_acc_samples = 0 + n_batch += 1 with torch.autograd.no_grad(): model.eval() @@ -910,10 +1180,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): ) time_current_result = datetime.datetime.now() - if time_pred_result is not None: - log_string( - f"next_result {time_current_result + (time_current_result - time_pred_result)}" - ) + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) time_pred_result = time_current_result checkpoint = {