X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=fridge;h=a4d860b73ac2f720363299030a75611f33110871;hb=26af4588b06ed463a4f9b9bcc4b527dd4c864d49;hp=143092cd4fb00e6e311d1276a9ead442e08992a1;hpb=8012a611e9920816fe6ba382b69305242136bc2a;p=mygptrnn.git diff --git a/fridge b/fridge index 143092c..a4d860b 100644 --- a/fridge +++ b/fridge @@ -335,3 +335,73 @@ class Calibrator: ) % k_star.size(0) k_star = k_star[l_barrel, t_barrel] + +###################################################################### + +2024 Feb 15 23:10:50 (from main.py) + + +def add_memex_v4(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u = torch.rand(t.size(), device=t.device) + u[:, : input.size(1)] = 1.0 + memex_v3_proba_fragment = 1 / 20 + u = (u < memex_v3_proba_fragment).long() + v = u * torch.randint(input.size(1), u.size()) + u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ + :, : input.size(1) - 1 + ] * input.size(1) + u = u.cumsum().clamp(min=0) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + + +###################################################################### + +2024 Feb 16 17:07:48 (from main.py) + + # ||gn + lambda * gm|| = max(||gn||,||gm||) + # ||gn||^2 + lambda + lambda^2||gm||^2 = max(||gn||^2,||gm||^2) + # A = ||gm||^2 B = C = ||gn||^2 - max(||gn||^2, ||gm||^2) + +###################################################################### + +2024 Feb 16 17:07:51 (from main.py) + + # A,B,C = gmgm, gngm, gngn - max(gngn,gmgm) + # Delta = B*B - 4*A*C + # if(delta >= 0): + # l = ( -B - sqrt(Delta))/(2*A) + # ||gn||+l*rho*||gm|| = max(||gn||,rho*||gm||)