[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is a from-scratch implementation of a "GPT", that is, a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation also provides RNN-like layers that can replace the MHA.
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # from blanket import blanket
25
26 # import memload
27
28 ######################################################################
29
30 # A BracketedSequence is a BxTx... tensor together with a first time
31 # step and a number (nb) of time steps to compute.
32
33 # Modules able to process it expect that they will have to process a
34 # first bracket starting at t=0, followed by a succession of brackets
35 # that move forward in time, do not overlap, and cover the axis T with
36 # no holes.
37 #
38 # Although it is more general, for a classical prompt-conditioned
39 # auto-regressive process it will be a first bracket starting at 0 and
40 # of arbitrary length for the "prompt", followed by brackets of length
41 # 1 for the successive tokens.
42 #
43 # Modules able to process brackets may implement a cache that is
44 # reset when init_cache is True.
45
46
47 class BracketedSequence:
48     def __init__(self, x, first=None, nb=None, init_cache=None):
49         self.x = x
50         assert (first is None and nb is None and init_cache is None) or (
51             first is not None and nb is not None and init_cache is not None
52         )
53
54         self.first = 0 if first is None else first
55         self.nb = x.size(1) if nb is None else nb
56         self.init_cache = True if init_cache is None else init_cache
57
58     def slice(self):
59         return self.x[:, self.first : self.first + self.nb]
60
61     def complete(self):
62         return self.first == 0 and self.nb == self.x.size(1)
63
64
65 ######################################################################
66
67
68 class CacheWrapper(nn.Module):
69     def __init__(self, *f):
70         super().__init__()
71         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
72
73     def forward(self, bs):
74         if bs.init_cache:
75             y = self.f(bs.slice())
76             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
77             self.cache_y[:, bs.first : bs.first + bs.nb] = y
78         else:
79             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
80             assert bs.first + bs.nb <= self.cache_y.size(1)
81             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
82
83         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
84
85
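
# A minimal usage sketch (added for illustration, not in the original
# file): it shows how a bracket-aware module is driven during a
# prompt-conditioned auto-regression, with one bracket for the prompt
# followed by brackets of length 1 for the generated tokens. The
# CacheWrapper around a plain nn.Linear stands in for any such module.

def _bracketed_sequence_usage_sketch():
    m = CacheWrapper(nn.Linear(4, 4))
    x = torch.randn(2, 10, 4)
    # First bracket: the "prompt", resets the cache
    y = m(BracketedSequence(x, first=0, nb=6, init_cache=True))
    # Then one time step at a time, re-using the cache
    for t in range(6, 10):
        y = m(BracketedSequence(x, first=t, nb=1, init_cache=False))
    return y.x  # (2, 10, 4), filled bracket after bracket
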
86 ##############################
87
88
89 class NaNChecker(nn.Module):
90     def __init__(self, name):
91         super().__init__()
92         self.name = name
93
94     def forward(self, bs):
95         x = bs.x if type(bs) is BracketedSequence else bs
96         assert not x.isnan().any(), f"{self.name} detected NaN"
97         assert not x.isinf().any(), f"{self.name} detected Inf"
98         return bs
99
100
101 class WithResidual(nn.Module):
102     def __init__(self, *f):
103         super().__init__()
104         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
105
106     def forward(self, bs):
107         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
108
109
110 ##############################
111
112
113 class AddPositionalEncoding(nn.Module):
114     def __init__(self, len_max):
115         super().__init__()
116         self.len_max = len_max
117
118     # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
119
120     def forward(self, bs):
121         if bs.init_cache:
122             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
123                 :, None
124             ]
125             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
126                 None, :
127             ]
128             k = j % 2
129             self.pe = torch.sin(
130                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
131             )
132             self.cache_y = bs.x.new(bs.x.size())
133
134         self.cache_y[:, bs.first : bs.first + bs.nb] = (
135             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
136         )
137
138         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
139
140
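
# A small verification sketch (not part of the original file): the single-sin
# expression above packs the usual sin/cos pairs into one formula, since
# sin(u + pi/2) = cos(u). Assumes an even feature dimension D.

def _positional_encoding_sketch(T=8, D=6, len_max=1e5):
    t = torch.arange(T, dtype=torch.float64)[:, None]
    j = torch.arange(D, dtype=torch.float64)[None, :]
    k = j % 2
    pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    # Explicit sin/cos reference
    i2 = torch.arange(0, D, 2, dtype=torch.float64)
    pe_ref = torch.empty(T, D, dtype=torch.float64)
    pe_ref[:, 0::2] = torch.sin(t / (len_max ** (i2 / D)))
    pe_ref[:, 1::2] = torch.cos(t / (len_max ** (i2 / D)))
    return (pe - pe_ref).abs().max()  # ~0
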
141 import pscan
142
143 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
144
145
146 def pscan_dim(A, X, Y_init, dim=-2):
147     s = X.size()
148     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
149
150     A = A.reshape(a, T, *s[dim + 1 : -1])
151     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
152
153     if Y_init is None:
154         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
155     else:
156         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
157
158     Y = pscan.pscan(A, X, Y_init).reshape(s)
159
160     return Y
161
162
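
# A naive reference (not in the original file) for the recurrence that
# pscan.pscan / pscan_dim are assumed to parallelize: a diagonal linear
# recurrence Y[t] = A[t] * Y[t-1] + X[t] along the time axis, started
# from Y_init. Written here for the dim=-2 layout of the comment above.

def _pscan_reference(A, X, Y_init):
    # A is /.../xT, X is /.../xTxD, Y_init is /.../xD
    y, ys = Y_init, []
    for t in range(X.size(-2)):
        y = A[..., t, None] * y + X[..., t, :]
        ys.append(y)
    return torch.stack(ys, dim=-2)
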
163 def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
164     with torch.no_grad():
165         s_A, s_X = 0, 0
166         for t in range(X.size(dim) - 1, 0, -1):
167             delta = (grad_Y[t] - s_A) / A[t].grad
168             s_A += A[t].grad * delta
169             A[t].grad = delta
170             delta = (grad_Y[t] - s_X) / X[t].grad
171             s_X += X[t].grad * delta
172             X[t].grad = delta
173
174
175 def pscan_shape(A, X, Y_init):
176     s = X.size()
177     A = A.reshape(-1, s[-2])
178     X = X.reshape(-1, s[-2], s[-1])
179
180     if Y_init is None:
181         Y_init = X.new_zeros(X.size(0), s[-1])
182     else:
183         Y_init = Y_init.reshape(-1, s[-1])
184
185     Y = pscan.pscan(A, X, Y_init).reshape(s)
186
187     return Y
188
189
190 def nsum_shape(X, Y_init):
191     s = X.size()
192     X = X.reshape(-1, s[-2], s[-1])  # ntd
193
194     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
195     result = []
196
197     for k in range(X.size(1)):
198         Y = Y + X[:, k]
199         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
200         result.append(Y)
201
202     return torch.cat(result, dim=1).reshape(s)
203
204
205 ##############################
206
207
208 class DumbRec(nn.Module):
209     def __init__(
210         self,
211         dim_model,
212         dim_qk,
213         dim_v,
214         nb_heads,
215         nb_lines,
216         attention_dropout=0.0,
217         len_max=1e5,
218         logger=print,
219         args=None,
220     ):
221         super().__init__()
222
223         def randw(*d):
224             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
225
226         self.nb_lines = nb_lines
227         self.attention_dropout = attention_dropout
228
229         self.k_star = randw(nb_lines, dim_qk)
230
231         self.w_qw = randw(nb_heads, dim_qk, dim_model)
232         self.w_qr = randw(nb_heads, dim_qk, dim_model)
233         self.w_v = randw(nb_heads, dim_v, dim_model)
234         self.w_o = randw(dim_v * nb_heads, dim_model)
235
236     def forward(self, bs):
237         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
238
239         if bs.init_cache:
240             self.rec_v = x_q.new_zeros(
241                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
242             )
243             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
244
245         ######################################################################
246         # Compute the recurrent state
247
248         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
249
250         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
251
252         aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
253             self.w_qw.size(1)
254         )
255
256         aw = aw.softmax(dim=2)  # nhlt
257
258         aw = F.dropout(aw, self.attention_dropout, self.training)
259
260         A = 1 - aw.sum(dim=1)  # nlt
261
262         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
263
264         if t0 == 0:
265             V0 = None
266         else:
267             V0 = self.rec_v[:, :, t0 - 1]
268
269         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
270
271         ######################################################################
272         # compute the readout
273
274         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
275
276         ar = torch.einsum(
277             "nhtd,ld->nhlt",
278             qr,
279             self.k_star,
280         ) / math.sqrt(self.w_qr.size(1))
281
282         ar = ar.softmax(dim=2)  # nhlt
283
284         ar = F.dropout(ar, self.attention_dropout, self.training)
285
286         y = torch.einsum(
287             "nhlt,nltd->nthd",
288             ar,
289             self.rec_v[:, :, t0:t1],
290         ).flatten(2)
291
292         self.cache_y[:, t0:t1] = y @ self.w_o
293
294         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
295
296
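
# A quick shape sketch (not in the original file, and assuming the pscan
# module of this repository is importable): DumbRec maps a BxTxdim_model
# bracket to a tensor of the same shape, with its recurrent state attached
# to nb_lines learned keys.

def _dumbrec_shape_sketch():
    m = DumbRec(dim_model=16, dim_qk=8, dim_v=8, nb_heads=2, nb_lines=4)
    x = torch.randn(3, 10, 16)
    y = m(BracketedSequence(x)).x
    return y.shape  # torch.Size([3, 10, 16])
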
297 ##############################
298
299
300 class KVRec(nn.Module):
301     def __init__(
302         self,
303         dim_model,
304         dim_qk,
305         dim_v,
306         nb_heads,
307         nb_lines,
308         attention_dropout=0.0,
309         len_max=1e5,
310         logger=print,
311         args=None,
312     ):
313         super().__init__()
314
315         def randw(*d):
316             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
317
318         self.nb_lines = nb_lines
319         self.attention_dropout = attention_dropout
320
321         self.k_star = randw(nb_lines, dim_qk)
322
323         self.w_qw = randw(nb_heads, dim_qk, dim_model)
324         self.w_qr = randw(nb_heads, dim_qk, dim_model)
325         self.w_k = randw(nb_heads, dim_qk, dim_model)
326         self.w_v = randw(nb_heads, dim_v, dim_model)
327         self.w_o = randw(dim_v * nb_heads, dim_model)
328
329     def reset_inner_loss(self):
330         self.acc_attention = 0
331         self.acc_nb = 0
332
333     def get_inner_loss(self):
334         # warnings.warn("l2 regularization", RuntimeWarning)
335         # return (self.acc_attention / self.acc_nb).pow(2).sum()
336         return torch.tensor([0], device=self.w_qw.device)
337         # warnings.warn("side regularization", RuntimeWarning)
338         # return (
339         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
340         # )
341         # return torch.tensor([0], device=self.w_qw.device)
342
343     def forward(self, bs):
344         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
345
346         if bs.init_cache:
347             self.rec_v = x_q.new_zeros(
348                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
349             )
350             self.rec_k = x_q.new_zeros(
351                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
352             )
353             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
354
355         ######################################################################
356         # Prepare the keys
357
358         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
359
360         # warnings.warn("rotating key barrel", RuntimeWarning)
361         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
362         t_barrel = torch.arange(t0, t1, device=k_star.device)
363         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
364         l_barrel = (
365             torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
366         ) % k_star.size(0)
367         k_star = k_star[l_barrel, t_barrel]
368
369         ######################################################################
370         # Compute the recurrent state
371
372         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
373
374         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
375         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
376
377         aw = torch.einsum(
378             "nhtd,ltd->nhlt",
379             qw,
380             k_star,
381         ) / math.sqrt(self.w_qw.size(1))
382
383         aw = aw.softmax(dim=2)  # nhlt
384
385         if self.training:
386             # We want all the memory lines to be used similarly
387             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
388             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
389
390         aw = F.dropout(aw, self.attention_dropout, self.training)
391
392         A = 1 - aw.sum(dim=1)  # nlt
393
394         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
395         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
396
397         if t0 == 0:
398             V0 = None
399             K0 = None
400         else:
401             V0 = self.rec_v[:, :, t0 - 1]
402             K0 = self.rec_k[:, :, t0 - 1]
403
404         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
405         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
406
407         ######################################################################
408         # compute the readout
409
410         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
411
412         ar = torch.einsum(
413             "nhtd,nltd->nhlt",
414             qr,
415             self.rec_k[:, :, t0:t1],
416         ) / math.sqrt(self.w_qr.size(1))
417
418         ar = ar.softmax(dim=2)  # nhlt
419
420         ar = F.dropout(ar, self.attention_dropout, self.training)
421
422         y = torch.einsum(
423             "nhlt,nltd->nthd",
424             ar,
425             self.rec_v[:, :, t0:t1],
426         ).flatten(2)
427
428         self.cache_y[:, t0:t1] = y @ self.w_o
429
430         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
431
432
433 ##############################
434
435
436 # Returns a tensor with an additional index at rank win_dim, that moves
437 # along the same dimension as dim, over a domain {0...win_size-1}, while
438 # dim is restricted to a domain reduced by win_size-1 values.
439
440
441 def moving_window(x, dim, win_dim, win_size):
442     size, stride = x.size(), x.stride()
443     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
444     size = size[:win_dim] + (win_size,) + size[win_dim:]
445     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
446
447     return x.as_strided(size=size, stride=stride)
448
449
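
# A small worked example (not in the original file) of moving_window on a
# toy tensor, to make the strided view concrete.

def _moving_window_example():
    x = torch.arange(6)[None, :]  # shape (1, 6)
    w = moving_window(x, dim=1, win_dim=2, win_size=3)
    # w has shape (1, 4, 3) and w[0] is
    # [[0, 1, 2],
    #  [1, 2, 3],
    #  [2, 3, 4],
    #  [3, 4, 5]]
    return w
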
450 ##############################
451
452
453 class Caterpillar(nn.Module):
454     def __init__(
455         self,
456         dim_model,
457         dim_qk,
458         dim_v,
459         nb_heads,
460         caterpillar_length,
461         caterpillar_height,
462         attention_dropout=0.0,
463         len_max=1e5,
464         logger=print,
465         args=None,
466     ):
467         super().__init__()
468
469         warnings.warn("Caterpillar", RuntimeWarning)
470
471         def randw(*d, factor=1):
472             return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))
473
474         self.caterpillar_length = caterpillar_length
475         self.caterpillar_height = caterpillar_height
476         self.attention_dropout = attention_dropout
477
478         ######################################################################
479
480         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
481         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
482
483         self.w_K = randw(nb_heads, dim_qk, dim_model)
484         self.w_V = randw(nb_heads, dim_v, dim_model)
485         self.w_Q = randw(nb_heads, dim_qk, dim_model)
486         self.w_O = randw(dim_v * nb_heads, dim_model)
487
488         self.init_K_rec = randw(
489             caterpillar_height,
490             caterpillar_length,
491             dim_qk,
492         )
493         self.init_V_rec = randw(
494             caterpillar_height,
495             caterpillar_length,
496             dim_v,
497         )
498
499     # def reset_inner_loss(self):
500     # self.acc_attention = 0
501     # self.acc_nb = 0
502
503     # def get_inner_loss(self):
504     # warnings.warn("l2 regularization", RuntimeWarning)
505     # return (self.acc_attention / self.acc_nb).pow(2).sum()
506     # return torch.tensor([0], device=self.w_Q.device)
507
508     def forward(self, bs):
509         # Name the dimensions to make the source a bit clearer; that's needed
510
511         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
512
513         N = bs.x.size(0)
514         T = bs.x.size(1)
515         H = self.w_V.size(0)
516         DV = self.w_V.size(1)
517         DK = self.w_K.size(1)
518         DM = self.w_O.size(1)
519         R = self.caterpillar_height
520         L = self.caterpillar_length
521
522         assert (
523             t0 >= L and (t1 - t0) % L == 0
524         ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
525
526         # We cache values to deal efficiently with auto-regression
527
528         if bs.init_cache:
529             self.rec_V = X.new_zeros(N, R, T, DV)
530             self.rec_K = X.new_zeros(N, R, T, DK)
531             # We start the recurrent sequences with optimizable
532             # initial values. No idea if it helps.
533             self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
534             self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
535
536             self.cache_Y = X.new_zeros(N, T, DM)
537
538         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
539         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
540
541         ######################################################################
542         # Compute the recurrent state
543
544         # This is the Gating sequence that modulates the storing of
545         # the new key and value in the R pairs of the current
546         # stack. There are R independent gating values, which means
547         # that the current K/V may be stored in multiple pairs of the
548         # recurrent state, or not at all.
549
550         G = (
551             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
552         ).sigmoid()
553
554         # Renormalize the gating so that the total over heads does
555         # not exceed 1 when several heads hit the same row
556
557         G = G / G.sum(1, keepdim=True).clamp(min=1)
558
559         ######################################################################
560
561         A = 1 - G.sum(dim=1)
562
563         gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
564         gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
565
566         # We start from cached values, which matters in inference
567
568         init_rec_V = self.rec_V[:, :, t0 - L : t0]
569         init_rec_K = self.rec_K[:, :, t0 - L : t0]
570
571         # Here there is a trick: Since the stack at position t is
572         # computed by updating that at position t-L, the parallel
573         # scan operates with a period of L. To do so we split the
574         # sequence indexing in two axes, the second of size L, and
575         # run the parallel scan using the first as the sequence index.
576
577         A = A.unflatten(2, (-1, L))
578         gated_V = gated_V.unflatten(2, (-1, L))
579         gated_K = gated_K.unflatten(2, (-1, L))
580
581         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
582         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
583
584         self.rec_V[:, :, t0:t1] = next_V
585         self.rec_K[:, :, t0:t1] = next_K
586
587         ######################################################################
588         # compute the readout
589
590         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
591
592         # Q = blanket(Q)
593
594         # We build attention scores NxHxTxRxL where N is the sample
595         # index, H the head, T the time, R the row in the caterpillar,
596         # and L the column in the caterpillar
597
598         windowed_V = moving_window(
599             self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
600         )
601
602         windowed_K = moving_window(
603             self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
604         )
605
606         # We have an attention score for each of the RxL values
607
608         ar = torch.einsum(
609             "nhtd,nrtld->nhtrl",
610             Q,
611             windowed_K,
612         ) / math.sqrt(DK)
613
614         # softmax can operate only on one dimension, hence the
615         # flattening
616
617         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
618
619         ar = F.dropout(ar, self.attention_dropout, self.training)
620
621         # Compute the output for each head, flatten to concatenate
622
623         Y = torch.einsum(
624             "nhtrl,nrtld->nthd",
625             ar,
626             windowed_V,
627         ).flatten(2)
628
629         self.cache_Y[:, t0:t1] = Y @ self.w_O
630
631         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
632
633
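
# A toy sketch (not in the original file) of the period-L trick used in
# Caterpillar.forward: unflattening the time axis into (T//L, L) and
# scanning over the first axis realizes the recurrence with stride L,
# i.e. the state at t is updated from the state at t-L.

def _period_L_scan_sketch(T=8, L=4):
    A, X = torch.rand(T), torch.rand(T)
    y_init = torch.zeros(L)
    # Direct recurrence with period L: y[t] = A[t] * y[t - L] + X[t]
    y_ref = torch.empty(T)
    for t in range(T):
        prev = y_init[t] if t < L else y_ref[t - L]
        y_ref[t] = A[t] * prev + X[t]
    # Same recurrence after unflattening time into (T // L, L) blocks and
    # scanning block by block, as done above before calling pscan_dim
    A_, X_ = A.unflatten(0, (T // L, L)), X.unflatten(0, (T // L, L))
    y, prev = torch.empty(T // L, L), y_init
    for b in range(T // L):
        prev = A_[b] * prev + X_[b]
        y[b] = prev
    return (y.flatten() - y_ref).abs().max()  # ~0
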
634 ##############################
635
636
637 class QKVAttention(nn.Module):
638     def __init__(
639         self,
640         dim_model,
641         dim_qk,
642         dim_v,
643         nb_heads=1,
644         causal=False,
645         horizon=None,
646         attention_dropout=0.0,
647         logger=print,
648         args=None,
649     ):
650         super().__init__()
651
652         def randw(*d):
653             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
654
655         self.causal = causal
656         self.horizon = horizon
657         self.attention_dropout = attention_dropout
658         self.record_attention = False
659
660         self.w_q = randw(nb_heads, dim_qk, dim_model)
661         self.w_k = randw(nb_heads, dim_qk, dim_model)
662         self.w_v = randw(nb_heads, dim_v, dim_model)
663         self.w_o = randw(dim_v * nb_heads, dim_model)
664
665     def forward(self, bs):
666         x_q = bs.x
667
668         assert (
669             self.causal or bs.complete()
670         ), "Partial evaluation is only possible for causal models"
671
672         if bs.init_cache:
673             self.cache_k = x_q.new_zeros(
674                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
675             )
676             self.cache_v = x_q.new_zeros(
677                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
678             )
679             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
680
681         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
682
683         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
684             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
685         )
686         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
687             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
688         )
689
690         a = torch.einsum(
691             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
692         ) / math.sqrt(self.w_q.size(1))
693
694         if self.causal:
695             if bs.init_cache:
696                 self.cache_attzero = (
697                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
698                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
699                 )
700
701                 if self.horizon is not None:
702                     self.cache_attzero = torch.logical_or(
703                         self.cache_attzero,
704                         torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
705                         >= torch.arange(x_q.size(1), device=q.device)[
706                             None, None, None, :
707                         ]
708                         + self.horizon,
709                     )
710
711             a = a.masked_fill(
712                 self.cache_attzero[
713                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
714                 ],
715                 float("-inf"),
716             )
717
718         a = a.softmax(dim=3)
719
720         if self.record_attention:
721             self.a = a
722
723         a = F.dropout(a, self.attention_dropout, self.training)
724
725         y = torch.einsum(
726             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
727         ).flatten(2)
728
729         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
730
731         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
732
733
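
# A quick consistency sketch (not in the original file): with causal
# attention, processing a full bracket or a prompt bracket followed by
# single-token brackets should give the same output, thanks to the K/V
# cache. Sizes are illustrative only.

def _qkv_cache_check(T=12, T0=4, D=16):
    att = QKVAttention(dim_model=D, dim_qk=8, dim_v=8, nb_heads=2, causal=True)
    att.eval()
    x = torch.randn(1, T, D)
    y_full = att(BracketedSequence(x)).x.clone()
    y = att(BracketedSequence(x, 0, T0, True))
    for t in range(T0, T):
        y = att(BracketedSequence(x, t, 1, False))
    return (y_full - y.x).abs().max()  # ~0
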
734 ##############################
735
736
737 class MyGPT(nn.Module):
738     def __init__(
739         self,
740         vocabulary_size,
741         dim_model,
742         dim_keys,
743         dim_hidden,
744         nb_heads,
745         nb_blocks,
746         nb_lines=None,
747         caterpillar_height=None,
748         causal=False,
749         dropout=0.0,
750         len_max=1e5,
751         attention_layer="caterpillar",
752         logger=print,
753         args=None,
754     ):
755         super().__init__()
756
757         self.vocabulary_size = vocabulary_size
758
759         assert attention_layer in {
760             "mha",
761             "dumbrec",
762             "kvrec",
763             "caterpillar",
764             "attcat",
765         }, f"Unknown attention operator {attention_layer}."
766
767         if attention_layer == "caterpillar" or attention_layer == "attcat":
768             assert nb_lines % caterpillar_height == 0
769             self.caterpillar_length = nb_lines // caterpillar_height
770             self.caterpillar_height = caterpillar_height
771         else:
772             self.caterpillar_length = -1
773             self.caterpillar_height = -1
774
775         assert dim_model % nb_heads == 0
776
777         self.embedding = nn.Sequential(
778             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
779             AddPositionalEncoding(len_max),
780         )
781
782         trunk_blocks = []
783
784         def attlayer():
785             if attention_layer == "mha":
786                 return WithResidual(
787                     CacheWrapper(nn.LayerNorm((dim_model,))),
788                     QKVAttention(
789                         dim_model=dim_model,
790                         dim_qk=dim_keys,
791                         dim_v=dim_model // nb_heads,
792                         nb_heads=nb_heads,
793                         causal=causal,
794                         attention_dropout=dropout,
795                         logger=logger,
796                         args=args,
797                     ),
798                 )
799             elif attention_layer == "dumbrec":
800                 return WithResidual(
801                     CacheWrapper(nn.LayerNorm((dim_model,))),
802                     DumbRec(
803                         dim_model=dim_model,
804                         dim_qk=dim_keys,
805                         dim_v=dim_model // nb_heads,
806                         nb_heads=nb_heads,
807                         nb_lines=nb_lines,
808                         attention_dropout=dropout,
809                         logger=logger,
810                         args=args,
811                     ),
812                 )
813             elif attention_layer == "kvrec":
814                 return WithResidual(
815                     CacheWrapper(nn.LayerNorm((dim_model,))),
816                     KVRec(
817                         dim_model=dim_model,
818                         dim_qk=dim_keys,
819                         dim_v=dim_model // nb_heads,
820                         nb_heads=nb_heads,
821                         nb_lines=nb_lines,
822                         attention_dropout=dropout,
823                         logger=logger,
824                         args=args,
825                     ),
826                 )
827             elif attention_layer == "caterpillar":
828                 return WithResidual(
829                     CacheWrapper(nn.LayerNorm((dim_model,))),
830                     Caterpillar(
831                         dim_model=dim_model,
832                         dim_qk=dim_keys,
833                         dim_v=dim_model // nb_heads,
834                         nb_heads=nb_heads,
835                         caterpillar_length=self.caterpillar_length,
836                         caterpillar_height=self.caterpillar_height,
837                         attention_dropout=dropout,
838                         logger=logger,
839                         args=args,
840                     ),
841                 )
842             elif attention_layer == "attcat":
843                 return nn.Sequential(
844                     WithResidual(
845                         CacheWrapper(nn.LayerNorm((dim_model,))),
846                         QKVAttention(
847                             dim_model=dim_model,
848                             dim_qk=dim_keys,
849                             dim_v=dim_model // nb_heads,
850                             nb_heads=nb_heads,
851                             causal=causal,
852                             horizon=self.caterpillar_length,
853                             attention_dropout=dropout,
854                             logger=logger,
855                             args=args,
856                         ),
857                     ),
858                     WithResidual(
859                         CacheWrapper(nn.LayerNorm((dim_model,))),
860                         Caterpillar(
861                             dim_model=dim_model,
862                             dim_qk=dim_keys,
863                             dim_v=dim_model // nb_heads,
864                             nb_heads=nb_heads,
865                             caterpillar_length=self.caterpillar_length,
866                             caterpillar_height=self.caterpillar_height,
867                             attention_dropout=dropout,
868                             logger=logger,
869                             args=args,
870                         ),
871                     ),
872                 )
873             else:
874                 raise ValueError(f"Unknown attention type {attention_layer}.")
875
876         for b in range(nb_blocks):
877             trunk_blocks += [
878                 attlayer(),
879                 WithResidual(
880                     CacheWrapper(
881                         nn.LayerNorm((dim_model,)),
882                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
883                         nn.ReLU(),
884                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
885                         nn.Dropout(dropout),
886                     ),
887                 ),
888             ]
889
890         self.trunk = nn.Sequential(*trunk_blocks)
891
892         self.readout = CacheWrapper(
893             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
894         )
895
896         with torch.no_grad():
897             for m in self.modules():
898                 if isinstance(m, nn.Embedding):
899                     m.weight.normal_(mean=0, std=2e-2)
900                 elif isinstance(m, nn.LayerNorm):
901                     m.bias.zero_()
902                     m.weight.fill_(1.0)
903
904         self.reset_inner_loss()
905
906     def forward(self, bs):
907         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
908
909         # To make the code simpler in the Caterpillar layer, we pad
910         # here. It's unclear if/how much it hurts computationally by
911         # increasing the sequence length for the other layers
912
913         if self.caterpillar_length > 0:
914             original_nb = bs.nb
915             if bs.nb % self.caterpillar_length > 0:
916                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
917
918             bs = BracketedSequence(
919                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
920                 bs.first + self.caterpillar_length,
921                 bs.nb,
922                 bs.init_cache,
923             )
924
925         bs = self.embedding(bs)
926         bs = self.trunk(bs)
927         bs = self.readout(bs)
928
929         if self.caterpillar_length > 0:
930             bs = BracketedSequence(
931                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
932                 bs.first - self.caterpillar_length,
933                 original_nb,
934                 bs.init_cache,
935             )
936
937         return bs
938
939     # ar_mask is a tensor with 0s and 1s, of the same shape as the input, with
940     # 1s where tokens should be generated. The others are kept
941     # unchanged.
942
943     def masked_inplace_autoregression(
944         self,
945         input_src,
946         ar_mask_src,
947         forbidden_tokens=None,
948         deterministic_synthesis=False,
949     ):
950         input = input_src.to(self.readout.f.weight.device)
951         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
952         to_generate = (ar_mask.sum(0) > 0).nonzero()
953         if to_generate.min() > 0:
954             self(
955                 BracketedSequence(input, 0, to_generate.min(), True)
956             )  # Needed to initialize the model's cache
957         for s in range(to_generate.min(), to_generate.max() + 1):
958             output = self(BracketedSequence(input, s, 1, s == 0)).x
959             logits = output[:, s]
960             if forbidden_tokens is not None:
961                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
962             if deterministic_synthesis:
963                 t_next = logits.argmax(1)
964             else:
965                 dist = torch.distributions.categorical.Categorical(logits=logits)
966                 t_next = dist.sample()
967             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
968
969         input_src.copy_(input)
970
971     def reset_inner_loss(self):
972         for m in self.modules():
973             if m is not self and hasattr(m, "reset_inner_loss"):
974                 m.reset_inner_loss()
975
976     def get_inner_loss(self):
977         l = torch.tensor([0.0], device=self.readout.f.weight.device)
978         for m in self.modules():
979             if m is not self and hasattr(m, "get_inner_loss"):
980                 l += m.get_inner_loss()
981         return l
982
983     def record_attention(self, v=True):
984         for m in self.modules():
985             if isinstance(m, QKVAttention):
986                 m.record_attention = v
987
988     def retrieve_attention(self):
989         a = []
990         for m in self.modules():
991             if isinstance(m, QKVAttention):
992                 a.append(m.a)
993         return a
994
995
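
# A minimal usage sketch (not in the original file) of the masked in-place
# auto-regression: keep a 4-token prefix and let the model fill in the
# remaining positions. All hyper-parameters here are illustrative only.

def _generation_sketch():
    model = MyGPT(
        vocabulary_size=32, dim_model=64, dim_keys=16, dim_hidden=128,
        nb_heads=4, nb_blocks=2, causal=True, attention_layer="mha",
    )
    model.eval()
    input = torch.randint(32, (2, 12))
    ar_mask = torch.zeros_like(input)
    ar_mask[:, 4:] = 1  # generate everything after the first 4 tokens
    model.masked_inplace_autoregression(input, ar_mask)
    return input
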
996 ######################################################################
997
998 if __name__ == "__main__":
999     import argparse
1000
1001     import numpy as np
1002     import matplotlib.pyplot as plt
1003     import matplotlib.collections as mc
1004
1005     args = argparse.Namespace(
1006         gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
1007     )
1008
1009     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1010
1011     dim_model, dim_keys, nb_heads = 512, 64, 1
1012     dropout = 0.1
1013
1014     caterpillar = Caterpillar(
1015         dim_model=dim_model,
1016         dim_qk=dim_keys,
1017         dim_v=dim_model // nb_heads,
1018         nb_heads=nb_heads,
1019         caterpillar_length=16,
1020         caterpillar_height=32,
1021         attention_dropout=dropout,
1022         args=args,
1023     ).to(device)
1024
1025     qkv = QKVAttention(
1026         dim_model=dim_model,
1027         dim_qk=dim_keys,
1028         dim_v=dim_model // nb_heads,
1029         nb_heads=nb_heads,
1030         causal=True,
1031         attention_dropout=dropout,
1032         args=args,
1033     ).to(device)
1034
1035     linear = CacheWrapper(nn.Linear(512, 512)).to(device)
1036
1037     x = torch.randn(1, 256, dim_model)
1038
1039     x = x.to(device)
1040     x.requires_grad_()
1041
1042     ######################################################################
1043
1044     fig = plt.figure()
1045     fig.set_figheight(6)
1046     fig.set_figwidth(8)
1047
1048     ax = fig.add_subplot(1, 1, 1)
1049
1050     # ax.set_xlim(-1.5, 1.5)
1051     # ax.set_ylim(-1.5, 1.5)
1052     # ax.set(aspect=1)
1053     # ax.spines.right.set_visible(False)
1054     # ax.spines.top.set_visible(False)
1055
1056     # dt = 0.01
1057     # t = np.arange(dt, 20.0, dt)
1058     # ax.semilogx(t, np.exp(-t / 5.0))
1059     # ax.grid()
1060     ax.set_yscale("log")
1061
1062     ######################################################################
1063
1064     for label, model, thickness in [
1065         ("nn.Linear", linear, 0.2),
1066         ("mygpt.QKVAttention", qkv, 1),
1067         ("mygpt.Caterpillar", caterpillar, 2),
1068     ]:
1069         y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
1070
1071         for n, p in [("input", x)] + list(model.named_parameters()):
1072             print(f"Processing {label}.{n}")
1073             data = []
1074             for t in range(y.size(1)):
1075                 sg = 0
1076                 for d in torch.randperm(y.size(2))[:8]:
1077                     sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
1078                 assert not sg.isinf().any()
1079                 assert not sg.isnan().any()
1080                 data.append([t, sg.sum().item()])
1081
1082             data = torch.tensor(data)
1083             # cx, cy = data[:, 0], data[:, 1]
1084             cy = data[:, 1].sort().values
1085             cx = torch.linspace(0, 1, cy.size(0))
1086             ax.plot(
1087                 cx, cy, label=label + "." + n, linewidth=thickness
1088             )  # , color='gray', label='Input')
1089
1090     # ax.legend(frameon=False, loc="top right")
1091
1092     # Put a legend to the right of the current axis
1093     box = ax.get_position()
1094     ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
1095     ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1096
1097     filename = "plot.pdf"
1098     print(f"saving {filename}")
1099     fig.savefig(filename, bbox_inches="tight")
1100
1101     # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
1102     # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
1103     # plt.show()
1104
1105     exit(0)
1106
1107     ######################################################################
1108
1109     m = Caterpillar(
1110         dim_model=4,
1111         dim_qk=3,
1112         dim_v=7,
1113         nb_heads=1,
1114         caterpillar_length=7,
1115         caterpillar_height=3,
1116         attention_dropout=0.0,
1117     )
1118
1119     m.reset_inner_loss()
1120     x = torch.randn(1, 21 + 2 * 7, 4)
1121     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1122     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1123     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1124     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1125     print((y1 - y2).abs().max())
1126     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1127     exit(0)
1128
1129     vocabulary_size = 128
1130     x = torch.randint(vocabulary_size, (6, 1024))
1131
1132     model = MyGPT(
1133         vocabulary_size=vocabulary_size,
1134         dim_model=512,
1135         dim_keys=64,
1136         dim_hidden=2048,
1137         nb_heads=8,
1138         nb_lines=128,
1139         nb_blocks=12,
1140         dropout=0.1,
1141         causal=True,
1142     )
1143
1144     x = x.to(device)
1145     model.to(device)
1146
1147     import time, sys
1148
1149     # import torchvision.models as models
1150     # from torch.profiler import profile, record_function, ProfilerActivity
1151
1152     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1153     # with record_function("model_inference"):
1154
1155     model.eval()
1156     for i in range(3):
1157         start_time = time.perf_counter()
1158         for k in range(10):
1159             model(BracketedSequence(x))
1160         duration = time.perf_counter() - start_time
1161         print(duration)
1162         sys.stdout.flush()
1163
1164     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1165     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1166
1167     # print("##############################################################")
1168     # y2 = torch.randn_like(y1)
1169     # for s in range(x.size(1)):
1170     # z = model(BracketedSequence(x, s, 1))
1171     # y2[:, s : s + 1] = z.slice()
1172
1173     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1174
1175 ######################################################################