X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=12b3631596b702e5f7d4f5e91d797ac4d6487298;hb=HEAD;hp=aded7967a4c8c0f4d84fdfd39085929b2e42c291;hpb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index aded796..12b3631 100755 --- a/mygpt.py +++ b/mygpt.py @@ -21,6 +21,8 @@ from torch.nn import functional as F import ffutils +# from blanket import blanket + # import memload ###################################################################### @@ -84,6 +86,18 @@ class CacheWrapper(nn.Module): ############################## +class NaNChecker(nn.Module): + def __init__(self, name): + super().__init__() + self.name = name + + def forward(self, bs): + x = bs.x if type(bs) is BracketedSequence else bs + assert not x.isnan().any(), f"${self.name} detected NaN" + assert not x.isinf().any(), f"${self.name} detected Inf" + return bs + + class WithResidual(nn.Module): def __init__(self, *f): super().__init__() @@ -126,7 +140,6 @@ class AddPositionalEncoding(nn.Module): import pscan - # X is /.../xTxD A is /.../xT Y_init is /.../xD @@ -147,6 +160,18 @@ def pscan_dim(A, X, Y_init, dim=-2): return Y +def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2): + with torch.no_grad(): + s_A, s_X = 0, 0 + for t in range(X.size(dim) - 1, 0, -1): + delta = (grad_Y[t] - s_A) / A[t].grad + s_A += A[t].grad * delta + A[t].grad = delta + delta = (grad_Y[t] - s_X) / X[t].grad + s_X += X[t].grad * delta + X[t].grad = delta + + def pscan_shape(A, X, Y_init): s = X.size() A = A.reshape(-1, s[-2]) @@ -191,7 +216,7 @@ class DumbRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -205,19 +230,9 @@ class DumbRec(nn.Module): self.w_qw = randw(nb_heads, dim_qk, dim_model) self.w_qr = randw(nb_heads, dim_qk, dim_model) - # self.w_k = randw(nb_heads, dim_qk, dim_model) self.w_v = randw(nb_heads, dim_v, dim_model) self.w_o = randw(dim_v * nb_heads, dim_model) - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 - - def get_inner_loss(self): - warnings.warn("l2 regularization", RuntimeWarning) - return (self.acc_attention / self.acc_nb).pow(2).sum() - # return torch.tensor([0], device=self.w_qw.device) - def forward(self, bs): x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb @@ -225,61 +240,33 @@ class DumbRec(nn.Module): self.rec_v = x_q.new_zeros( x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1) ) - # self.rec_k = x_q.new_zeros( - # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1) - # ) self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) - ###################################################################### - # Prepare the keys - - k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) - - warnings.warn("rotating key barrel", RuntimeWarning) - k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) - t_barrel = torch.arange(t0, t1, device=k_star.device) - t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) - l_barrel = ( - torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel - ) % k_star.size(0) - k_star = k_star[l_barrel, t_barrel] - ###################################################################### # Compute the recurrent state qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw) v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v) - # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k) - aw = torch.einsum( - "nhtd,ltd->nhlt", - qw, - k_star, - ) / math.sqrt(self.w_qw.size(1)) + aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt( + self.w_qw.size(1) + ) aw = aw.softmax(dim=2) # nhlt - if self.train: - self.acc_attention += aw.sum(dim=(0, 1, 3)) - self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3) - aw = F.dropout(aw, self.attention_dropout, self.training) A = 1 - aw.sum(dim=1) # nlt V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous() - # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous() if t0 == 0: V0 = None - # K0 = None else: V0 = self.rec_v[:, :, t0 - 1] - # K0 = self.rec_k[:, :, t0 - 1] self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0) - # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0) ###################################################################### # compute the readout @@ -289,7 +276,6 @@ class DumbRec(nn.Module): ar = torch.einsum( "nhtd,ld->nhlt", qr, - # self.rec_k[:, :, t0:t1], self.k_star, ) / math.sqrt(self.w_qr.size(1)) @@ -322,7 +308,7 @@ class KVRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -345,9 +331,9 @@ class KVRec(nn.Module): self.acc_nb = 0 def get_inner_loss(self): - warnings.warn("l2 regularization", RuntimeWarning) - return (self.acc_attention / self.acc_nb).pow(2).sum() - # return torch.tensor([0], device=self.w_qw.device) + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + return torch.tensor([0], device=self.w_qw.device) # warnings.warn("side regularization", RuntimeWarning) # return ( # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum() @@ -371,12 +357,12 @@ class KVRec(nn.Module): k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) - warnings.warn("rotating key barrel", RuntimeWarning) + # warnings.warn("rotating key barrel", RuntimeWarning) k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) t_barrel = torch.arange(t0, t1, device=k_star.device) t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) l_barrel = ( - torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + torch.arange(k_star.size(0), device=k_star.device)[:, None] # + t_barrel ) % k_star.size(0) k_star = k_star[l_barrel, t_barrel] @@ -464,36 +450,6 @@ def moving_window(x, dim, win_dim, win_size): ############################## -class Calibrator: - def __init__(self, w=None, b=None): - self.w = w - self.b = b - self.s, self.s_sq, self.n = 0, 0, 0 - self.mean, self.std = 0, 0 - - def update(self, X): - X = X.detach() - self.s += X.sum(dim=0) - self.s_sq += X.pow(2).sum(dim=0) - self.n += X.size(0) - - def moments(self): - mean = self.s / self.n - std = (self.s_sq / self.n - mean * mean).sqrt() - return mean, std - - def normalize(self): - mean, std = self.moments() - if self.b is not None: - self.b.sub_(mean) - if self.w is not None: - self.w.div_(std) - result = mean - self.mean, std - self.std - self.mean, self.std = mean, std - self.s, self.s_sq, self.n = 0, 0, 0 - return result - - class Caterpillar(nn.Module): def __init__( self, @@ -506,44 +462,23 @@ class Caterpillar(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d, amplitude=None): - if amplitude is None: - amplitude = 1 / math.sqrt(d[-1]) - return nn.Parameter(amplitude * torch.randn(*d)) + def randw(*d, factor=1): + return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1])) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - ###################################################################### - # sup_args - - x = kwargs.get("gate_dropout") - if x is None: - self.proba_gate_dropout = 0.0 - else: - self.proba_gate_dropout = float(x) - - logger(f"self.proba_gate_dropout {self.proba_gate_dropout}") - - x = kwargs.get("default_bg") - if x is None: - default_bg = -math.log(caterpillar_height - 1) - else: - default_bg = float(x) - - logger(f"default_bg {default_bg}") - ###################################################################### self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) @@ -561,18 +496,14 @@ class Caterpillar(nn.Module): dim_v, ) - self.calibrator_G = Calibrator() - self.calibrator_rec_V = Calibrator() - self.calibrator_rec_K = Calibrator() + # def reset_inner_loss(self): + # self.acc_attention = 0 + # self.acc_nb = 0 - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 - - def get_inner_loss(self): - # warnings.warn("l2 regularization", RuntimeWarning) - # return (self.acc_attention / self.acc_nb).pow(2).sum() - return torch.tensor([0], device=self.w_Q.device) + # def get_inner_loss(self): + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_Q.device) def forward(self, bs): # Dimensions to make the source a bit clearer, that's needed @@ -620,56 +551,14 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - self.calibrator_G.update(G.reshape(-1, G.size(-1))) - - # warnings.warn("softmax gating", RuntimeWarning) - - # G = ( - # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] - # ).softmax(dim=2) - - ###################################################################### - # The "flashbacks" - - if self.training and self.proba_gate_dropout > 0.0: - # This is a better implementation of "flashbacks". - - # G is NxHxExT where e is the caterpillar's row. - - warnings.warn("gate dropout", RuntimeWarning) - - kill = ( - torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout - ).float() - - alpha = G / (1 - self.proba_gate_dropout) - - G = alpha * (1 - kill) - - ###################################################################### # Clip the gating to avoid values greater than 1 when several # heads hit the same row G = G / G.sum(1, keepdim=True).clamp(min=1) ###################################################################### - # Roll the gating indexes - - # warnings.warn("rotating barrel", RuntimeWarning) - - # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - # r_barrel = (r_barrel + (t_barrel + t0) // L) % R - # G = G.gather(dim=2, index=r_barrel.expand_as(G)) - # We prepare the arguments for the parallel scan - - A = 1 - G.sum(1) - - # warnings.warn("harmonic recurrence", RuntimeWarning) - # har = torch.arange(t0, t1, device = G.device).float() + 1 - # A = har / (har + 1) - # G = G / har + A = 1 - G.sum(dim=1) gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) @@ -679,9 +568,6 @@ class Caterpillar(nn.Module): init_rec_V = self.rec_V[:, :, t0 - L : t0] init_rec_K = self.rec_K[:, :, t0 - L : t0] - ################################################################# - # Associative scan - # Here there is a trick: Since the stack at position t is # computed by updating that at position t-L, the parallel # scan operates with a period of L. To do so we split the @@ -692,18 +578,8 @@ class Caterpillar(nn.Module): gated_V = gated_V.unflatten(2, (-1, L)) gated_K = gated_K.unflatten(2, (-1, L)) - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) - - next_V = next_V.flatten(2, 3) - next_K = next_K.flatten(2, 3) - - self.calibrator_rec_V.update( - next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2)) - ) - self.calibrator_rec_K.update( - next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2)) - ) + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) self.rec_V[:, :, t0:t1] = next_V self.rec_K[:, :, t0:t1] = next_K @@ -713,8 +589,10 @@ class Caterpillar(nn.Module): Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - # We build tensors NxHxTxFxL where N is the sample index, H - # the head, T the time, F the row in the caterpillar, and L + # Q = blanket(Q) + + # We build tensors NxHxTxRxL where N is the sample index, H + # the head, T the time, R the row in the caterpillar, and L # the column in the caterpillar windowed_V = moving_window( @@ -728,7 +606,7 @@ class Caterpillar(nn.Module): # We have an attention score for each of the RxL values ar = torch.einsum( - "nhtd,nftld->nhtfl", + "nhtd,nrtld->nhtrl", Q, windowed_K, ) / math.sqrt(DK) @@ -748,8 +626,6 @@ class Caterpillar(nn.Module): windowed_V, ).flatten(2) - # Compute the final output - self.cache_Y[:, t0:t1] = Y @ self.w_O return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) @@ -766,9 +642,10 @@ class QKVAttention(nn.Module): dim_v, nb_heads=1, causal=False, + horizon=None, attention_dropout=0.0, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -776,6 +653,7 @@ class QKVAttention(nn.Module): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) self.causal = causal + self.horizon = horizon self.attention_dropout = attention_dropout self.record_attention = False @@ -819,6 +697,17 @@ class QKVAttention(nn.Module): torch.arange(x_q.size(1), device=q.device)[None, None, :, None] < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] ) + + if self.horizon is not None: + self.cache_attzero = torch.logical_or( + self.cache_attzero, + torch.arange(x_q.size(1), device=q.device)[None, None, :, None] + >= torch.arange(x_q.size(1), device=q.device)[ + None, None, None, : + ] + + self.horizon, + ) + a = a.masked_fill( self.cache_attzero[ :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb @@ -859,20 +748,23 @@ class MyGPT(nn.Module): causal=False, dropout=0.0, len_max=1e5, - attention_layer="kvrec", + attention_layer="caterpillar", logger=print, - **kwargs, + args=None, ): super().__init__() + self.vocabulary_size = vocabulary_size + assert attention_layer in { "mha", "dumbrec", "kvrec", "caterpillar", + "attcat", }, f"Unknown attention operator {attention_layer}." - if attention_layer == "caterpillar": + if attention_layer == "caterpillar" or attention_layer == "attcat": assert nb_lines % caterpillar_height == 0 self.caterpillar_length = nb_lines // caterpillar_height self.caterpillar_height = caterpillar_height @@ -891,59 +783,99 @@ class MyGPT(nn.Module): def attlayer(): if attention_layer == "mha": - return QKVAttention( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - causal=causal, - attention_dropout=dropout, - logger=logger, - **kwargs, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "dumbrec": - return DumbRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - **kwargs, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + DumbRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "kvrec": - return KVRec( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - nb_lines=nb_lines, - attention_dropout=dropout, - logger=logger, - **kwargs, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + KVRec( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + nb_lines=nb_lines, + attention_dropout=dropout, + logger=logger, + args=args, + ), ) elif attention_layer == "caterpillar": - return Caterpillar( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - caterpillar_length=self.caterpillar_length, - caterpillar_height=self.caterpillar_height, - attention_dropout=dropout, - logger=logger, - **kwargs, + return WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ) + elif attention_layer == "attcat": + return nn.Sequential( + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=causal, + horizon=self.caterpillar_length, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), + WithResidual( + CacheWrapper(nn.LayerNorm((dim_model,))), + Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=self.caterpillar_length, + caterpillar_height=self.caterpillar_height, + attention_dropout=dropout, + logger=logger, + args=args, + ), + ), ) else: raise ValueError(f"Unknown attention type {attention_layer}.") for b in range(nb_blocks): trunk_blocks += [ - WithResidual( - CacheWrapper(nn.LayerNorm((dim_model,))), - attlayer(), - ), + attlayer(), WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), @@ -1064,7 +996,115 @@ class MyGPT(nn.Module): ###################################################################### if __name__ == "__main__": - print("Basic check.") + import argparse + + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.collections as mc + + args = argparse.Namespace( + gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dim_model, dim_keys, nb_heads = 512, 64, 1 + dropout = 0.1 + + caterpillar = Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=16, + caterpillar_height=32, + attention_dropout=dropout, + args=args, + ).to(device) + + qkv = QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=True, + attention_dropout=dropout, + args=args, + ).to(device) + + linear = CacheWrapper(nn.Linear(512, 512)).to(device) + + x = torch.randn(1, 256, dim_model) + + x = x.to(device) + x.requires_grad_() + + ###################################################################### + + fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(8) + + ax = fig.add_subplot(1, 1, 1) + + # ax.set_xlim(-1.5, 1.5) + # ax.set_ylim(-1.5, 1.5) + # ax.set(aspect=1) + # ax.spines.right.set_visible(False) + # ax.spines.top.set_visible(False) + + # dt = 0.01 + # t = np.arange(dt, 20.0, dt) + # ax.semilogx(t, np.exp(-t / 5.0)) + # ax.grid() + ax.set_yscale("log") + + ###################################################################### + + for label, model, thickness in [ + ("nn.Linear", linear, 0.2), + ("mygpy.QKVAttention", qkv, 1), + ("mygpt.Caterpillar", caterpillar, 2), + ]: + y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x + + for n, p in [("input", x)] + list(model.named_parameters()): + print(f"Processing {model}.{n}") + data = [] + for t in range(y.size(1)): + sg = 0 + for d in torch.randperm(y.size(2))[:8]: + sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0] + assert not sg.isinf().any() + assert not sg.isnan().any() + data.append([t, sg.sum().item()]) + + data = torch.tensor(data) + # cx, cy = data[:, 0], data[:, 1] + cy = data[:, 1].sort().values + cx = torch.linspace(0, 1, cy.size(0)) + ax.plot( + cx, cy, label=label + "." + n, linewidth=thickness + ) # , color='gray', label='Input') + + # ax.legend(frameon=False, loc="top right") + + # Put a legend to the right of the current axis + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + filename = "plot.pdf" + print(f"saving {filename}") + fig.savefig(filename, bbox_inches="tight") + + # if args.window and hasattr(plt.get_current_fig_manager(), 'window'): + # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + # plt.show() + + exit(0) + + ###################################################################### m = Caterpillar( dim_model=4, @@ -1086,8 +1126,6 @@ if __name__ == "__main__": print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max()) exit(0) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - vocabulary_size = 128 x = torch.randint(vocabulary_size, (6, 1024))