From a3c32b845b6903fd290f2b09d5c53203ff112b79 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 8 Feb 2024 06:50:44 +0100 Subject: [PATCH] Update. --- main.py | 28 ++++++- maxval.py | 108 ++++++++++++++++--------- mygpt.py | 233 ++++++++++++++++++++++++++---------------------------- 3 files changed, 209 insertions(+), 160 deletions(-) diff --git a/main.py b/main.py index ec50722..4d5077a 100755 --- a/main.py +++ b/main.py @@ -87,6 +87,8 @@ parser.add_argument("--model", type=str, default=None) parser.add_argument("--attention", type=str, default=None) +parser.add_argument("--proportion_memex", type=float, default=0) + parser.add_argument("--dim_model", type=int, default=None) parser.add_argument("--dim_keys", type=int, default=None) @@ -101,9 +103,9 @@ parser.add_argument("--caterpillar_height", type=int, default=None) parser.add_argument("--gate_dropout_proba", type=float, default=0.0) -parser.add_argument("--gate_dropout_sync", type=str2bool, default=True) +parser.add_argument("--gate_dropout_sync", type=str2bool, default=False) -parser.add_argument("--gate_dropout_replace", type=str2bool, default=True) +parser.add_argument("--gate_dropout_replace", type=str2bool, default=False) parser.add_argument("--rho_inner_loss", type=float, default=0.0) @@ -736,6 +738,9 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() +if args.proportion_memex > 0: + vocabulary_size += 1 + log_string(f"vocabulary_size {vocabulary_size}") ############################## @@ -897,7 +902,24 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - for input in task.batches(split="train"): + def add_memex(batches, proportion_memex): + for input in batches: + if torch.rand(1).item() < proportion_memex: + yield torch.cat( + [ + input, + torch.full( + (input.size(0), 1), vocabulary_size - 1, device=input.device + ), + input, + ], + dim=1, + ) + yield input + + train_batches = add_memex(task.batches(split="train"), args.proportion_memex) + + for input in train_batches: model.reset_inner_loss() input = input.to(device) diff --git a/maxval.py b/maxval.py index a245287..99a7efb 100755 --- a/maxval.py +++ b/maxval.py @@ -5,17 +5,41 @@ import torch ###################################################################### -def baseline(X, V): +def baseline1(X, V): Y = X.new(X.size()) W = V.new(V.size()) + for t in range(X.size(1)): if t == 0: Y[:, t] = X[:, t] W[:, t] = V[:, t] else: - m = (V[:, t] >= W[:, t - 1] - 1).long() - Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1] - W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1) + m = (W[:, t - 1] - 1 >= V[:, t]).long() + W[:, t] = m * (W[:, t - 1] - 1) + (1 - m) * V[:, t] + Y[:, t] = m * Y[:, t - 1] + (1 - m) * ( + X[:, t] * (1 + dv) + Y[:, t - 1] * dv0 + ) + + return Y, W + + +###################################################################### + + +def hs(x): + return x.sigmoid() # (x >= 0).float() + (x - x.detach()) * (x < 0).float() + + +def baseline(X, V): + for t in range(X.size(1)): + if t == 0: + Y = X[:, t] + W = V[:, t] + else: + m = (W - 1 - V[:, t]).sigmoid() + # m = hs(W - 1 - V[:, t]) + W = m * (W - 1) + (1 - m) * V[:, t] + Y = m * Y + (1 - m) * X[:, t] return Y, W @@ -72,34 +96,36 @@ def pscan_diff(X, V, s=1): Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2) # [:, :, 0] < [:, :, 1] - dx = Xf[:, :, 1] - Xf[:, :, 1].detach() + dv0 = (Vf[:, :, 0] - Vf[:, :, 0].detach())[:, :, None] dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None] m = (Vf[:, :, 
0] - s >= Vf[:, :, 1]).long() Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] m = m[:, :, None] - Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx) + Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + Xf[:, :, 0] * dv0) Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2) - Xr[:, 0] = X[:, 0] - Vr[:, 0] = V[:, 0] - # [:, :-1, 1] < [:, 1:, 0] - dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach() + dv0 = (Vrf[:, :-1, 1] - Vrf[:, :-1, 1].detach())[:, :, None] dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None] m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] m = m[:, :, None] - Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx) + Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * ( + Xf[:, 1:, 0] * (1 + dv) + Xrf[:, :-1, 1] * dv0 + ) + + Xr[:, 0] = X[:, 0] + Vr[:, 0] = V[:, 0] if T < X.size(1): # [:, -2] < [:, -1] - dx = X[:, -1] - X[:, -1].detach() + dx = X[:, -2] - X[:, -2].detach() dv = (V[:, -1] - V[:, -1].detach())[:, None] m = (V[:, -2] - s >= V[:, -1]).long() - Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] + Vr[:, -1] = m * (Vr[:, -2] - s) + (1 - m) * V[:, -1] m = m[:, None] - Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx) + Xr[:, -1] = m * Xr[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx) return Xr, Vr @@ -108,27 +134,35 @@ def pscan_diff(X, V, s=1): if __name__ == "__main__": N = 1 - T = 513 - D = 2 + T = 64 + D = 128 - X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() - V = torch.rand(N, T, dtype=torch.float64) * 10 + torch.autograd.set_detect_anomaly(True) - X0, V0 = baseline(X, V) + for k in range(0): + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + V = torch.rand(N, T, dtype=torch.float64) - # print("########### X0 V0 ###########################################") - # print(V0) - # print(X0) + X0, V0 = baseline(X, V) - X1, V1 = pscan_diff(X, V) + # print("########### X0 V0 ###########################################") + # print(V0) + # print(X0) - # print("########### X V ############################################") - # print(V) - # print(X) + X1, V1 = pscan_diff(X, V) - print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()) + # print("########### X V ############################################") + # print(V) + # print(X) - exit(0) + error = ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item() + if error > 0: + print("ERROR", error) + print(X0) + print(X1) + exit(0) + + # exit(0) # s = X1.sum() # print(torch.autograd.grad(s, X)) @@ -138,18 +172,18 @@ if __name__ == "__main__": # f.write(f"{V1[0,t].item()}\n") Y = torch.randn(1, 1, D) - X = torch.randn( - N, T, D - ) # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y - V = torch.rand(N, T).requires_grad_() + X = torch.randn(N, T, D) * 0.1 + + m = (torch.rand(N, T, 1).sort(dim=1).indices == 0).float() + X = (1 - m) * X + m * Y + V = torch.rand(N, T) # + 100* m.squeeze(dim=-1) + V = V.requires_grad_() - optimizer = torch.optim.SGD([V], lr=1e-2) + optimizer = torch.optim.SGD([V], lr=1e-1) for k in range(1000): - X1, V1 = X.clone(), V.clone() - pscan(X, V, X1, V1) - # X1=X1*(1+V1-V1.detach())[:,:,None] - loss = (X1[:, -1:] - Y).pow(2).mean() + X1, V1 = baseline(X, V) + loss = (X1 - Y).pow(2).mean() print(k, loss.item()) optimizer.zero_grad() loss.backward() diff --git a/mygpt.py b/mygpt.py index 67c5cfd..c833012 100755 --- a/mygpt.py +++ b/mygpt.py @@ -502,13 +502,9 @@ class Caterpillar(nn.Module): self.caterpillar_height = caterpillar_height 
self.attention_dropout = attention_dropout - self.gate_dropout_proba = args.gate_dropout_proba - self.gate_dropout_sync = args.gate_dropout_sync - self.gate_dropout_replace = args.gate_dropout_replace - ###################################################################### - self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3) + self.w_G = randw(nb_heads, caterpillar_height, dim_model) self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0)) self.w_K = randw(nb_heads, dim_qk, dim_model) @@ -569,8 +565,6 @@ class Caterpillar(nn.Module): V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) - # V, K = blanket(V), blanket(K) - ###################################################################### # Compute the recurrent state @@ -589,81 +583,30 @@ class Caterpillar(nn.Module): G = G / G.sum(1, keepdim=True).clamp(min=1) - # G_star = (1 - G).log().sum(1, keepdim=True).exp() - ###################################################################### - def recurrence(G, V, K): - # We prepare the arguments for the parallel scan - - A = 1 - G.sum(dim=1) - - gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) - gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - - # We start from cached values, which matters in inference - - init_rec_V = self.rec_V[:, :, t0 - L : t0] - init_rec_K = self.rec_K[:, :, t0 - L : t0] - - # Here there is a trick: Since the stack at position t is - # computed by updating that at position t-L, the parallel - # scan operates with a period of L. To do so we split the - # sequence indexing in two axes, the second of size L, and - # run the parallel scan using the first as the sequence index. - - A = A.unflatten(2, (-1, L)) - gated_V = gated_V.unflatten(2, (-1, L)) - gated_K = gated_K.unflatten(2, (-1, L)) - - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) + A = 1 - G.sum(dim=1) - return next_V, next_K + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - ################################################################# + # We start from cached values, which matters in inference - next_V, next_K = recurrence(G, V, K) + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - if self.training and self.gate_dropout_proba > 0.0: - # G is NxHxRxT where r is the caterpillar's row. + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. 
- warnings.warn("gate dropout", RuntimeWarning) + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) - if self.gate_dropout_sync: - shape_kill = (N, 1, 1) - else: - shape_kill = (N, H, R) - - # Pick a point in each of the NxHxR timeline and set this - # entry and the following to 1 - kill = ( - torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices - == 0 - ).cumsum(dim=3) - - # Keep these mask for only some of the NxHxR - kill = kill * ( - torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba - ) - - # The coefficient to keep are the complementary - mask = 1 - kill - - masked_next_V, masked_next_K = recurrence(G * mask, V, K) - - if self.gate_dropout_replace: - next_V = next_V.detach() - next_K = next_K.detach() - - warnings.warn("the rescaling is probably a bad idea", RuntimeWarning) - - next_V = next_V + (masked_next_V - masked_next_V.detach()) / ( - 1 - self.gate_dropout_proba - ) - next_K = next_K + (masked_next_K - masked_next_K.detach()) / ( - 1 - self.gate_dropout_proba - ) + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) self.rec_V[:, :, t0:t1] = next_V self.rec_K[:, :, t0:t1] = next_K @@ -710,10 +653,6 @@ class Caterpillar(nn.Module): windowed_V, ).flatten(2) - # Compute the final output - - # Y = blanket(Y) - self.cache_Y[:, t0:t1] = Y @ self.w_O return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) @@ -730,6 +669,7 @@ class QKVAttention(nn.Module): dim_v, nb_heads=1, causal=False, + horizon=None, attention_dropout=0.0, logger=print, args=None, @@ -740,6 +680,7 @@ class QKVAttention(nn.Module): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) self.causal = causal + self.horizon = horizon self.attention_dropout = attention_dropout self.record_attention = False @@ -783,6 +724,17 @@ class QKVAttention(nn.Module): torch.arange(x_q.size(1), device=q.device)[None, None, :, None] < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] ) + + if self.horizon is not None: + self.cache_attzero = torch.logical_or( + self.cache_attzero, + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + >= torch.arange(x_q.size(1), device=q.device)[ + None, None, None, : + ] + + self.horizon, + ) + a = a.masked_fill( self.cache_attzero[ :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb @@ -834,9 +786,10 @@ class MyGPT(nn.Module): "dumbrec", "kvrec", "caterpillar", + "attcat", }, f"Unknown attention operator {attention_layer}." 
- if attention_layer == "caterpillar": + if attention_layer == "caterpillar" or attention_layer == "attcat": assert nb_lines % caterpillar_height == 0 self.caterpillar_length = nb_lines // caterpillar_height self.caterpillar_height = caterpillar_height @@ -855,59 +808,99 @@ class MyGPT(nn.Module): def attlayer(): if attention_layer == "mha": - return QKVAttention( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - causal=causal, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "dumbrec": - return DumbRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + DumbRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "kvrec": - return KVRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + KVRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "caterpillar": - return Caterpillar( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - caterpillar_length=self.caterpillar_length, - caterpillar_height=self.caterpillar_height, - attention_dropout=dropout, - logger=logger, - args=args, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ) + elif attention_layer == "attcat": + return nn.Sequential( + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + horizon=self.caterpillar_length, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), ) else: raise ValueError(f"Unknown attention type {attention_layer}.") for b in range(nb_blocks): trunk_blocks += [ - WithResidual( - CacheWrapper(nn.LayerNorm((dim_model,))), - attlayer(), - ), + attlayer(), WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), -- 2.20.1
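######################################################################

The add_memex generator added to main.py wraps a training batch, with probability --proportion_memex, as [input, separator, input], the separator being the extra token id vocabulary_size - 1 reserved when the option is active; the original batch is always yielded afterwards. A minimal stand-alone sketch with a toy vocabulary size (illustrative, not the repository code):

import torch

vocabulary_size = 11  # toy value; the real one comes from the task, plus one

def add_memex(batches, proportion_memex):
    for input in batches:
        if torch.rand(1).item() < proportion_memex:
            sep = torch.full(
                (input.size(0), 1), vocabulary_size - 1, device=input.device
            )
            yield torch.cat([input, sep, input], dim=1)
        yield input

batches = [torch.randint(vocabulary_size - 1, (2, 5)) for _ in range(3)]
for b in add_memex(batches, proportion_memex=1.0):
    print(b.size())  # alternates (2, 11) and (2, 5) when the proportion is 1.0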
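######################################################################

The reworked baseline() in maxval.py implements a soft, decaying running maximum: W keeps the best score seen so far, decremented by one per step, Y carries the row of X that produced it, and the hard comparison is replaced by a sigmoid so that gradients reach V. A sketch of that recurrence, with the broadcasting made explicit so it also runs for N > 1:

import torch

N, T, D = 2, 6, 3
X = torch.randn(N, T, D)
V = torch.rand(N, T, requires_grad=True)

Y, W = X[:, 0], V[:, 0]
for t in range(1, T):
    m = (W - 1 - V[:, t]).sigmoid()  # ~1 when the decayed max still wins
    W = m * (W - 1) + (1 - m) * V[:, t]  # decay the kept score or replace it
    Y = m[:, None] * Y + (1 - m[:, None]) * X[:, t]  # carry the matching row of X

Y.sum().backward()
print(V.grad.abs().sum().item() > 0)  # gradients flow back to the scores V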
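######################################################################

The comment kept in Caterpillar.forward describes the trick that makes the recurrence amenable to a parallel scan: the state at t is computed from the state at t - L, so reindexing time as (T // L, L) turns it into L independent scans along the first new axis, which is what pscan_dim receives. A naive sequential reference checking that reindexing (not the actual pscan_dim implementation):

import torch

T, L, D = 12, 4, 3
A = torch.rand(T)
X = torch.randn(T, D)
init = torch.zeros(L, D)  # plays the role of the cached rec_V / rec_K slice

# Direct computation of rec[t] = A[t] * rec[t - L] + X[t]
rec = torch.empty(T, D)
for t in range(T):
    prev = init[t] if t < L else rec[t - L]
    rec[t] = A[t] * prev + X[t]

# Same recurrence after splitting time into (T // L, L) and scanning over the
# first axis only, all L phases updated in parallel at each step
Af = A.unflatten(0, (T // L, L))  # (T // L, L)
Xf = X.unflatten(0, (T // L, L))  # (T // L, L, D)
rec2 = torch.empty(T // L, L, D)
state = init
for k in range(T // L):
    state = Af[k][:, None] * state + Xf[k]
    rec2[k] = state
rec2 = rec2.flatten(0, 1)

print((rec - rec2).abs().max().item())  # ~0, the two computations agree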
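######################################################################

The new attcat layer stacks a windowed QKVAttention and a Caterpillar, each in its own LayerNorm + residual wrapper, and the window comes from the horizon argument added to QKVAttention: on top of the causal mask, keys more than horizon - 1 positions in the past are masked out. A sketch of the resulting mask, with illustrative sizes (in the patch, horizon is caterpillar_length):

import torch

T, horizon = 8, 3
i = torch.arange(T)[:, None]  # query positions
j = torch.arange(T)[None, :]  # key positions

causal = i < j  # a query cannot look at future keys
too_old = i >= j + horizon  # nor at keys more than horizon - 1 steps back
attzero = torch.logical_or(causal, too_old)

print((~attzero).long())  # a band of ones of width horizon along the diagonal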