Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.
12
# This implementation is also equipped with recurrent (RNN-like) layers
# that can replace the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: its first time step `first` and its length `nb`.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor paired with the bracket [first, first + nb) of time
    steps that remain to be computed.

    ``first``, ``nb`` and ``init_cache`` must be either all given or all
    omitted. When omitted, the bracket spans the whole time axis and
    ``init_cache`` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        if init_cache is None:
            init_cache = True

        self.first, self.nb, self.init_cache = first, nb, init_cache

    def slice(self):
        """Return the time steps of ``x`` that fall inside the bracket."""
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        """Tell whether the bracket covers the entire time axis."""
        return self.nb == self.x.size(1) and self.first == 0
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket only, caching its output.

    ``f`` is a single module, or several modules chained with
    ``nn.Sequential``. It is applied to ``bs.slice()`` and its output is
    written into a cache that spans the full time axis. Successive brackets
    therefore fill the cache progressively, and the returned
    BracketedSequence wraps that cache.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-initialized rather than left uninitialized (Tensor.new),
            # so that time steps not computed yet hold deterministic values
            # instead of garbage that may contain NaN/inf.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
            self.cache_y[:, t0:t1] = y
        else:
            # The cache must have been built for a sequence of the same
            # batch size and length.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert t1 <= self.cache_y.size(1)
            self.cache_y[:, t0:t1] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection over BracketedSequences: output is x + f(x).

    Several modules given as arguments are chained with ``nn.Sequential``.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        identity = bs.x
        branch = self.f(bs).x
        return BracketedSequence(identity + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add the sinusoidal positional encoding to a BracketedSequence.

    [Vaswani et al. 2017]
        PE_{t,2i}   = sin(t / L^{2i/D})
        PE_{t,2i+1} = cos(t / L^{2i/D})
    with L = ``len_max`` and D the size of the last axis of ``bs.x``. The
    cosine terms are computed as sin(. + pi/2).
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
            k = j % 2  # 0 for the sine columns, 1 for the cosine columns
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k
            )
            # Zero-initialized so that time steps outside the processed
            # brackets hold deterministic values, not uninitialized memory.
            self.cache_y = x.new_zeros(x.size())

        t0, t1 = bs.first, bs.first + bs.nb
        self.cache_y[:, t0:t1] = bs.slice() + self.pe[t0:t1]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
# Shapes for the scans below: X is ...xTxD, A is ...xT, Y_init is ...xD
130
131
def pscan_dim(A, X, Y_init, dim=-2):
    """Apply ``pscan.pscan`` along axis ``dim`` of X.

    Shapes, with T the size of axis ``dim`` and D the last axis of X:
        X:      (*before, T, *between, D)
        A:      (*before, T, *between)
        Y_init: reshapeable to (prod(before), *between, D), or None for zeros

    The ``before`` axes are flattened into a single batch axis before calling
    ``pscan.pscan``, and the result is reshaped back to the shape of X.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    # Axes strictly between the scanned axis and the last (feature) axis
    between = s[dim + 1 : -1]

    A = A.reshape(a, T, *between)
    X = X.reshape(a, T, *between, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *between, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *between, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
147
148
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    """Experimental backward-in-time rewriting of the .grad of A and X.

    NOTE(review): this looks unfinished and is not called anywhere in the
    visible part of this file. As written:
      - it always indexes the first axis (grad_Y[t], A[t], X[t]); ``dim``
        only determines the number of iterations;
      - A[t] and X[t] are fresh indexing results, whose .grad is typically
        None, so the divisions would presumably fail, and the assignments to
        .grad land on temporaries -- TODO confirm intent before using;
      - ``Y_init`` and ``eps`` are unused;
      - t = 0 is never visited (the range stops at 1).
    """
    with torch.no_grad():
        # Running sums of the contributions already assigned to later steps
        s_A, s_X = 0, 0
        # Walk time backwards, from the last step down to t = 1
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            A[t].grad = delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta
            X[t].grad = delta
159
160
def pscan_shape(A, X, Y_init):
    """Apply ``pscan.pscan`` along the second-to-last axis of X.

    X is ...xTxD and A is ...xT. All leading axes are flattened into a single
    batch axis. Y_init is reshaped to (-1, D), and None means zeros. The
    result has the shape of X.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)
    flat_Y0 = (
        flat_X.new_zeros(flat_X.size(0), D)
        if Y_init is None
        else Y_init.reshape(-1, D)
    )

    return pscan.pscan(flat_A, flat_X, flat_Y0).reshape(shape)
174
175
def nsum_shape(X, Y_init):
    """Running sum along the second-to-last axis of X, kept in the unit ball.

    X is ...xTxD. For each t, Y_t = P(Y_{t-1} + X_t), where
    P(y) = y / max(1, |y|) rescales y so that its L2 norm never exceeds 1.
    Y_{-1} is Y_init reshaped to (-1, D), or 0 when Y_init is None. The
    result has the shape of X.
    """
    s = X.size()

    if X.numel() == 0:
        # Nothing to accumulate. The reshape below would be ambiguous, and
        # torch.stack would fail on an empty list.
        return X.clone()

    X = X.reshape(-1, s[-2], s[-1])  # N x T x D

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for x_t in X.unbind(dim=1):
        Y = Y + x_t
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.stack(result, dim=1).reshape(s)
189
190
191 ##############################
192
193
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention with ``nb_lines``
    memory lines addressed only through learned keys ``k_star``.

    Write: each head attends over the lines with a query computed from the
    input (``w_qw``) against ``k_star``, rotated with time (see forward).
    The attention weights give, for every line, the coefficient
    A = 1 - (sum over heads) passed to the parallel scan, and the amount of
    each head's value that is added to it.

    Read: each head attends over the lines with a second query (``w_qr``)
    against the unrotated ``k_star`` and takes the weighted sum of the
    lines' current values.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility and are unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)  # write queries
        self.w_qr = randw(nb_heads, dim_qk, dim_model)  # read queries
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the squared L2 norm of the per-line mean write attention.

        The mean is over the write attentions accumulated by training-mode
        forward passes since the last reset_inner_loss().
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state: one value vector per (sample, line, time step)
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l is written using the key
        # k_star[(l + t) % nb_lines] ("rotating barrel")

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        x = x_q[:, t0:t1]

        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # self.train is a method (always truthy); self.training is the flag
        # telling whether the module is in training mode.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line coefficient for the scan. Presumably the update is
        # rec_t = A_t * rec_{t-1} + V_t (pscan is defined in another
        # module) -- TODO confirm.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # When resuming an auto-regression, continue from the state at t0 - 1
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        # NOTE(review): reading uses the unrotated k_star while writing uses
        # the rotated keys; presumably intentional -- confirm.
        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
320
321
322 ##############################
323
324
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention with ``nb_lines``
    memory lines, each holding a recurrent key and a recurrent value.

    Write: each head attends over the lines with a query computed from the
    input (``w_qw``) against the learned keys ``k_star``, rotated with time
    (see forward). The attention weights give, for every line, the
    coefficient A = 1 - (sum over heads) passed to the parallel scan, and
    the amounts of each head's key and value that are added to it.

    Read: each head attends over the lines with a second query (``w_qr``)
    against the lines' current recurrent keys and takes the weighted sum of
    their recurrent values.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility and are unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned addressing key per memory line
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)  # write queries
        self.w_qr = randw(nb_heads, dim_qk, dim_model)  # read queries
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the squared L2 norm of the per-line mean write attention.

        The mean is over the write attentions accumulated by training-mode
        forward passes since the last reset_inner_loss().
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state: one value and one key vector per
            # (sample, line, time step)
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l is written using the key
        # k_star[(l + t) % nb_lines] ("rotating barrel")

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        x = x_q[:, t0:t1]

        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x, self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # self.train is a method (always truthy); self.training is the flag
        # telling whether the module is in training mode.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line coefficient for the scan. Presumably the update is
        # rec_t = A_t * rec_{t-1} + X_t (pscan is defined in another
        # module) -- TODO confirm.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # When resuming an auto-regression, continue from the state at t0 - 1
        if t0 == 0:
            V0, K0 = None, None
        else:
            V0, K0 = self.rec_v[:, :, t0 - 1], self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
456
457
458 ##############################
459
460
# Returns a tensor with an additional index at rank win_dim that moves
# along the same dimension as dim, on a domain {0...win_size-1}, while
# dim is restricted to a domain reduced by win_size-1 values.
464
465
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of ``x`` with a sliding window along ``dim``.

    A new axis of size ``win_size`` is inserted at position ``win_dim``. It
    steps with the same memory stride as ``dim``, and ``dim`` itself is
    shortened to ``x.size(dim) - win_size + 1`` positions. No data is
    copied: the result shares storage with ``x``.
    """
    shape = list(x.size())
    strides = list(x.stride())

    shape[dim] -= win_size - 1
    shape.insert(win_dim, win_size)
    strides.insert(win_dim, x.stride()[dim])

    return x.as_strided(size=tuple(shape), stride=tuple(strides))
473
474
475 ##############################
476
477
class Caterpillar(nn.Module):
    """Recurrent replacement for multi-head attention.

    The recurrent state is made of R = ``caterpillar_height`` rows of
    key/value pairs. The state at time t is obtained by updating the state
    at time t - L, with L = ``caterpillar_length``. The readout at time t
    attends over the R x L key/value pairs of time steps t - L + 1 ... t.

    Optional keyword arguments:
        gate_dropout: probability of gate dropout during training (default 0).
        default_bg: initial value of the gate biases (default
            -log(caterpillar_height - 1)).

    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################
        # sup_args

        x = kwargs.get("gate_dropout")
        if x is None:
            self.proba_gate_dropout = 0.0
        else:
            self.proba_gate_dropout = float(x)

        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")

        x = kwargs.get("default_bg")
        if x is None:
            default_bg = -math.log(caterpillar_height - 1)
        else:
            default_bg = float(x)

        logger(f"default_bg {default_bg}")

        ######################################################################

        # Gating logits, one per (head, row), computed from the input
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned initial content of the R x L recurrent key/value pairs
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        """Reset the accumulated attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return a constant zero: this layer has no inner regularization."""
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        # The recurrence reads the L time steps before t0 and processes the
        # bracket in chunks of L time steps
        assert t0 >= L and (t1 - t0) % L == 0, (
            "bs.first should be greater than or equal to caterpillar_length, "
            "and bs.nb should be a multiple of caterpillar_length"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################

        def recurrence(G, V, K):
            # We prepare the arguments for the parallel scan

            A = 1 - G.sum(1)

            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

            # We start from cached values, which matters in inference

            init_rec_V = self.rec_V[:, :, t0 - L : t0]
            init_rec_K = self.rec_K[:, :, t0 - L : t0]

            # Here there is a trick: Since the stack at position t is
            # computed by updating that at position t-L, the parallel
            # scan operates with a period of L. To do so we split the
            # sequence indexing in two axes, the second of size L, and
            # run the parallel scan using the first as the sequence index.

            A = A.unflatten(2, (-1, L))
            gated_V = gated_V.unflatten(2, (-1, L))
            gated_K = gated_K.unflatten(2, (-1, L))

            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

            next_V = next_V.flatten(2, 3)
            next_K = next_K.flatten(2, 3)

            return next_V, next_K

        #################################################################

        next_V, next_K = recurrence(G, V, K)

        if self.training and self.proba_gate_dropout > 0.0:
            # G is NxHxRxT where r is the caterpillar's row.

            warnings.warn("gate dropout", RuntimeWarning)

            # Pick a point in each of the NxHxR timeline and set this
            # entry and the following to 1
            kill = (
                torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
            ).cumsum(dim=3)

            # Keep these mask for only some of the NxHxR
            kill = kill * (
                torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
            )

            # The coefficient to keep are the complementary
            mask = 1 - kill

            masked_next_V, masked_next_K = recurrence(G * mask, V, K)

            # Forward values come from the unmasked recurrence, gradients
            # from the masked one, rescaled by 1 / (1 - p)
            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
                1 - self.proba_gate_dropout
            )
            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
                1 - self.proba_gate_dropout
            )

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxRxL where N is the sample index, H
        # the head, T the time, R the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
718
719
720 ##############################
721
722
class QKVAttention(nn.Module):
    """Multi-head QKV attention on BracketedSequences, optionally causal.

    Keys and values of past time steps are cached, so that a causal model
    can be evaluated bracket by bracket during auto-regression. A
    non-causal model must be evaluated on the complete sequence at once.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Initialization order matters for reproducibility under a fixed seed
        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        N, T = x.size(0), x.size(1)

        if bs.init_cache:
            self.cache_k = x.new_zeros(N, self.w_k.size(0), T, self.w_k.size(1))
            self.cache_v = x.new_zeros(N, self.w_v.size(0), T, self.w_v.size(1))
            self.cache_y = x.new_zeros(N, T, self.w_o.size(1))

        x_new = x[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Queries of the bracket against every key up to its end
        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                positions = torch.arange(T, device=q.device)
                # True where the key position s is strictly after the query t
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            a = a.masked_fill(self.cache_attzero[:, :, t0:t1, :t1], float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
805
806
807 ##############################
808
809
810 class MyGPT(nn.Module):
811     def __init__(
812         self,
813         vocabulary_size,
814         dim_model,
815         dim_keys,
816         dim_hidden,
817         nb_heads,
818         nb_blocks,
819         nb_lines=None,
820         caterpillar_height=None,
821         causal=False,
822         dropout=0.0,
823         len_max=1e5,
824         attention_layer="kvrec",
825         logger=print,
826         **kwargs,
827     ):
828         super().__init__()
829
830         assert attention_layer in {
831             "mha",
832             "dumbrec",
833             "kvrec",
834             "caterpillar",
835         }, f"Unknown attention operator {attention_layer}."
836
837         if attention_layer == "caterpillar":
838             assert nb_lines % caterpillar_height == 0
839             self.caterpillar_length = nb_lines // caterpillar_height
840             self.caterpillar_height = caterpillar_height
841         else:
842             self.caterpillar_length = -1
843             self.caterpillar_height = -1
844
845         assert dim_model % nb_heads == 0
846
847         self.embedding = nn.Sequential(
848             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
849             AddPositionalEncoding(len_max),
850         )
851
852         trunk_blocks = []
853
854         def attlayer():
855             if attention_layer == "mha":
856                 return QKVAttention(
857                     dim_model=dim_model,
858                     dim_qk=dim_keys,
859                     dim_v=dim_model // nb_heads,
860                     nb_heads=nb_heads,
861                     causal=causal,
862                     attention_dropout=dropout,
863                     logger=logger,
864                     **kwargs,
865                 )
866             elif attention_layer == "dumbrec":
867                 return DumbRec(
868                     dim_model=dim_model,
869                     dim_qk=dim_keys,
870                     dim_v=dim_model // nb_heads,
871                     nb_heads=nb_heads,
872                     nb_lines=nb_lines,
873                     attention_dropout=dropout,
874                     logger=logger,
875                     **kwargs,
876                 )
877             elif attention_layer == "kvrec":
878                 return KVRec(
879                     dim_model=dim_model,
880                     dim_qk=dim_keys,
881                     dim_v=dim_model // nb_heads,
882                     nb_heads=nb_heads,
883                     nb_lines=nb_lines,
884                     attention_dropout=dropout,
885                     logger=logger,
886                     **kwargs,
887                 )
888             elif attention_layer == "caterpillar":
889                 return Caterpillar(
890                     dim_model=dim_model,
891                     dim_qk=dim_keys,
892                     dim_v=dim_model // nb_heads,
893                     nb_heads=nb_heads,
894                     caterpillar_length=self.caterpillar_length,
895                     caterpillar_height=self.caterpillar_height,
896                     attention_dropout=dropout,
897                     logger=logger,
898                     **kwargs,
899                 )
900             else:
901                 raise ValueError(f"Unknown attention type {attention_layer}.")
902
903         for b in range(nb_blocks):
904             trunk_blocks += [
905                 WithResidual(
906                     CacheWrapper(nn.LayerNorm((dim_model,))),
907                     attlayer(),
908                 ),
909                 WithResidual(
910                     CacheWrapper(
911                         nn.LayerNorm((dim_model,)),
912                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
913                         nn.ReLU(),
914                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
915                         nn.Dropout(dropout),
916                     ),
917                 ),
918             ]
919
920         self.trunk = nn.Sequential(*trunk_blocks)
921
922         self.readout = CacheWrapper(
923             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
924         )
925
926         with torch.no_grad():
927             for m in self.modules():
928                 if isinstance(m, nn.Embedding):
929                     m.weight.normal_(mean=0, std=2e-2)
930                 elif isinstance(m, nn.LayerNorm):
931                     m.bias.zero_()
932                     m.weight.fill_(1.0)
933
934         self.reset_inner_loss()
935
936     def forward(self, bs):
937         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
938
939         # To make the code simpler in the Caterpillar layer, we pad
940         # here. It's unclear if/how much it hurts computationaly by
941         # increasing the sequence length for the other layers
942
943         if self.caterpillar_length > 0:
944             original_nb = bs.nb
945             if bs.nb % self.caterpillar_length > 0:
946                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
947
948             bs = BracketedSequence(
949                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
950                 bs.first + self.caterpillar_length,
951                 bs.nb,
952                 bs.init_cache,
953             )
954
955         bs = self.embedding(bs)
956         bs = self.trunk(bs)
957         bs = self.readout(bs)
958
959         if self.caterpillar_length > 0:
960             bs = BracketedSequence(
961                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
962                 bs.first - self.caterpillar_length,
963                 original_nb,
964                 bs.init_cache,
965             )
966
967         return bs
968
    # ar_mask is a tensor of 0s and 1s with the same shape as input: 1s
    # mark the positions where tokens should be generated, and the tokens
    # at all other positions are kept unchanged.
972
973     def masked_inplace_autoregression(
974         self,
975         input_src,
976         ar_mask_src,
977         forbidden_tokens=None,
978         deterministic_synthesis=False,
979     ):
980         input = input_src.to(self.readout.f.weight.device)
981         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
982         to_generate = (ar_mask.sum(0) > 0).nonzero()
983         if to_generate.min() > 0:
984             self(
985                 BracketedSequence(input, 0, to_generate.min(), True)
986             )  # Needed to initialize the model's cache
987         for s in range(to_generate.min(), to_generate.max() + 1):
988             output = self(BracketedSequence(input, s, 1, s == 0)).x
989             logits = output[:, s]
990             if forbidden_tokens is not None:
991                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
992             if deterministic_synthesis:
993                 t_next = logits.argmax(1)
994             else:
995                 dist = torch.distributions.categorical.Categorical(logits=logits)
996                 t_next = dist.sample()
997             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
998
999         input_src.copy_(input)
1000
1001     def reset_inner_loss(self):
1002         for m in self.modules():
1003             if m is not self and hasattr(m, "reset_inner_loss"):
1004                 m.reset_inner_loss()
1005
1006     def get_inner_loss(self):
1007         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1008         for m in self.modules():
1009             if m is not self and hasattr(m, "get_inner_loss"):
1010                 l += m.get_inner_loss()
1011         return l
1012
1013     def record_attention(self, v=True):
1014         for m in self.modules():
1015             if isinstance(m, QKVAttention):
1016                 m.record_attention = v
1017
1018     def retrieve_attention(self):
1019         a = []
1020         for m in self.modules():
1021             if isinstance(m, QKVAttention):
1022                 a.append(m.a)
1023         return a
1024
1025
1026 ######################################################################
1027
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Consistency check of the Caterpillar layer's cache. The same 21 steps
    # are processed three ways:
    #   - in one bracket;
    #   - in one bracket again;
    #   - as two consecutive brackets of 14 and 7 steps.
    # The two printed maxima compare these outputs (presumably expected to
    # be ~0).

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # By default the script stops after the basic check. Pass --benchmark on
    # the command line to also run the throughput benchmark below.
    # sys.exit is used rather than the site-provided exit(), which is only
    # meant for the interactive interpreter.
    if "--benchmark" not in sys.argv[1:]:
        sys.exit(0)

    # NOTE(review): the MyGPT arguments below have not been checked against
    # the current constructor signature; confirm before relying on them.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1096
1097     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1098
1099 ######################################################################