From 408f2335af43590ee2d99c3286cbe3762c76887a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 20 Feb 2024 08:52:38 +0100 Subject: [PATCH] Update. --- main.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/main.py b/main.py index 2a90fd1..00b8301 100755 --- a/main.py +++ b/main.py @@ -570,7 +570,8 @@ def add_memex_v1(batches, memex_proba, marker_token): yield input -def add_memex_v2(batches, memex_proba): +# The marker token is not used for this one +def add_memex_v2(batches, memex_proba, marker_token): for input in batches: if torch.rand(1).item() < memex_proba: t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand( @@ -595,6 +596,47 @@ def add_memex_v2(batches, memex_proba): yield input +def add_memex_v3(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + memex_len = input.size(1) // 4 + + t = torch.arange(input.size(1) + memex_len, device=input.device)[ + None, : + ].expand(input.size(0), -1) + + # Call me the tensor-spaghetti master + + trigger = torch.rand(t.size(), device=t.device) + trigger[:, -memex_len:] = 1.0 + trigger = (trigger.sort(dim=1).indices == 0).long() + memex_mask = trigger.clone() + memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len] + memex_mask = memex_mask.cumsum(dim=1) + u = 1 - memex_mask + u[:, 0] = 0 + u = u.cumsum(dim=1) + # assert u.min() == 0 + # assert u.max() == input.size(1) - 1 + v = ( + (trigger.cumsum(dim=1) - trigger).cumsum(dim=1) + + torch.randint(input.size(1), (input.size(0), 1), device=t.device) + ) * memex_mask + u = u * (1 - memex_mask) + v * memex_mask + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + new_input = input[n, u] + limits = trigger.clone() + limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)] + new_input = new_input * (1 - limits) + memex_marker * limits + + yield new_input, memex_mask + + else: + yield input + + ###################################################################### assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} @@ -814,6 +856,7 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() if args.memex_proba > 0: + memex_marker = vocabulary_size vocabulary_size += 1 log_string(f"vocabulary_size {vocabulary_size}") @@ -975,7 +1018,18 @@ def the_dot_products(value1, value2, params): return torch.cat([g1g1, g1g2, g2g2]) -movave_dot_products = 0 +def update_ave_grad(value, params, name, eps=1e-3): + for p in params: + g = torch.autograd.grad(value, p, retain_graph=True)[0] + ag = getattr(p, name) if hasattr(p, name) else 0 + setattr(p, name, (1 - eps) * ag + eps * g) + + +def norm(params, name): + s = 0 + for p in params: + s += getattr(p, name).pow(2).sum() + return s for n_epoch in range(nb_epochs_finished, nb_epochs): @@ -1000,9 +1054,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): log_string(f"memex_proba {memex_proba}") - train_batches = add_memex_v2( + warnings.warn("memex v3", RuntimeWarning) + train_batches = add_memex_v3( batches=task.batches(split="train"), memex_proba=memex_proba, + marker_token=memex_marker, ) def add_none(it): @@ -1032,18 +1088,16 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): loss_regular = (loss * (1 - memex_mask)).mean() loss_memex = (loss * memex_mask).mean() - if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01: - dot_products = the_dot_products( - loss_regular, loss_memex, model.parameters() - ) - eps = 1e-3 - movave_dot_products = ( - 1 - eps - ) * movave_dot_products + eps * dot_products + if it < 100 or torch.rand(1) < 0.01: + update_ave_grad(loss_regular, model.parameters(), "grad_regular") + update_ave_grad(loss_memex, model.parameters(), "grad_memex") + norm_regular = norm(model.parameters(), "grad_regular") + norm_memex = norm(model.parameters(), "grad_memex") + l_memex = ( + max(norm_regular, norm_memex) - norm_regular + ) / norm_memex - grgr, grgm, gmgm = movave_dot_products - l = (max(grgr, gmgm) - grgr) / gmgm - loss = loss_regular + l * loss_memex + loss = loss_regular + l_memex * loss_memex inner_loss = model.get_inner_loss() @@ -1072,9 +1126,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): optimizer.step() grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") - grgr, grgm, gmgm = movave_dot_products - l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm) - lambda_file.write(f"{n_epoch} {n_batch} {l} {grgr} {gmgm}\n") + lambda_file.write( + f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n" + ) optimizer.zero_grad() nb_acc_samples = 0 n_batch += 1 -- 2.20.1