X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=a587e967af4fb9ec05fdca261f11decc55df8ac4;hb=9d4193312e06ed284b1368b7f4407f2b4f981c7a;hp=00b8301177379851e40b544df1ddec9163f469cd;hpb=408f2335af43590ee2d99c3286cbe3762c76887a;p=mygptrnn.git diff --git a/main.py b/main.py index 00b8301..a587e96 100755 --- a/main.py +++ b/main.py @@ -604,32 +604,46 @@ def add_memex_v3(batches, memex_proba, marker_token): t = torch.arange(input.size(1) + memex_len, device=input.device)[ None, : ].expand(input.size(0), -1) + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) # Call me the tensor-spaghetti master trigger = torch.rand(t.size(), device=t.device) - trigger[:, -memex_len:] = 1.0 - trigger = (trigger.sort(dim=1).indices == 0).long() + trigger[:, -memex_len:] = 2.0 + trigger[:, 0] = 2.0 + trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long() memex_mask = trigger.clone() - memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len] + memex_mask[:, memex_len:] -= trigger[:, :-memex_len] memex_mask = memex_mask.cumsum(dim=1) + u = 1 - memex_mask u[:, 0] = 0 u = u.cumsum(dim=1) - # assert u.min() == 0 - # assert u.max() == input.size(1) - 1 + assert u.min() == 0 + assert u.max() == input.size(1) - 1 + v = ( (trigger.cumsum(dim=1) - trigger).cumsum(dim=1) - + torch.randint(input.size(1), (input.size(0), 1), device=t.device) + + torch.randint( + input.size(1) - memex_len, (input.size(0), 1), device=t.device + ) ) * memex_mask + assert v.min() >= 0 + assert v.max() < input.size(1) u = u * (1 - memex_mask) + v * memex_mask - n = torch.arange(input.size(0), device=input.device)[:, None].expand( - -1, t.size(1) - ) + new_input = input[n, u] + assert input.max() < vocabulary_size + assert new_input.max() < vocabulary_size limits = trigger.clone() limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)] - new_input = new_input * (1 - limits) + memex_marker * limits + assert limits.min() == 0 + assert limits.max() == 1 + new_input = new_input * (1 - limits) + marker_token * limits + assert marker_token < vocabulary_size + assert new_input.max() < vocabulary_size yield new_input, memex_mask