Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# This is an implementation from scratch of a "GPT", that is, a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.
12
# This implementation also provides recurrent (RNN-like) layers that can
# be used in place of the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: nb steps starting at time step first.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A N x T x ... tensor plus the bracket [first, first + nb) of time steps to process.

    Either all of ``first``, ``nb`` and ``init_cache`` are given, or none of
    them. In the latter case the bracket covers the whole time axis
    (first=0, nb=x.size(1)) and ``init_cache`` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        bracket = (first, nb, init_cache)
        assert all(v is None for v in bracket) or all(v is not None for v in bracket)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the view of ``x`` restricted to the bracket along axis 1."""
        end = self.first + self.nb
        return self.x[:, self.first : end]

    def complete(self):
        """Return True when the bracket spans the entire time axis."""
        if self.first != 0:
            return False
        return self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the bracket of a BracketedSequence and cache the result.

    ``f`` is the single module given, or an ``nn.Sequential`` of all of them
    otherwise. It is applied to ``bs.slice()``. Its output is written into
    ``self.cache_y``, whose axis 1 has the full length ``bs.x.size(1)``.

    When ``bs.init_cache`` is True, a fresh uninitialized cache is allocated
    and only the bracket is filled. Later calls fill further brackets of the
    same cache.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        lo, hi = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # Continuing an existing cache: batch and time sizes must agree
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert hi <= self.cache_y.size(1)
            self.cache_y[:, lo:hi] = self.f(bs.slice())
        else:
            out = self.f(bs.slice())
            full_shape = (out.size(0), bs.x.size(1)) + tuple(out.size()[2:])
            self.cache_y = out.new_empty(full_shape)
            self.cache_y[:, lo:hi] = out

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Skip connection: add the wrapped module's output tensor to its input tensor.

    The wrapped module(s) take and return BracketedSequences. The sum is taken
    over the full tensors, and the input's bracket is kept.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        branch = self.f(bs)
        summed = bs.x + branch.x
        return BracketedSequence(summed, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a BracketedSequence.

    Following [Vaswani et al. 2017], with D = x.size(2) and L = len_max:
        PE_{t,2i}   = sin(t / L^{2i/D})
        PE_{t,2i+1} = cos(t / L^{2i/D})
    The cosine is computed as sin(. + pi/2). The encoding table and an output
    cache of the size of ``bs.x`` are (re)built when ``bs.init_cache`` is True.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            dim = x.size(2)
            t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(dim, dtype=x.dtype, device=x.device)[None, :]
            parity = j % 2  # 0 for sin columns, 1 for cos columns
            self.pe = torch.sin(
                t / (self.len_max ** ((j - parity) / dim)) + math.pi / 2 * parity
            )
            self.cache_y = x.new_empty(x.size())

        lo, hi = bs.first, bs.first + bs.nb
        self.cache_y[:, lo:hi] = bs.slice() + self.pe[lo:hi]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` with axis ``dim`` of X as the sequence axis.

    All axes before ``dim`` are flattened into a single batch axis. The axes
    between ``dim`` and the last one are kept as is. The result has the shape
    of X.

    A must have as many elements as X without its last axis. Y_init, when not
    None, is expected to have the shape of X with axis ``dim`` removed. When
    None, a zero initial state is used.

    NOTE(review): ``pscan.pscan`` lives in a separate module. It is presumably
    the parallel scan Y_t = A_t * Y_{t-1} + X_t — confirm there.
    """
    shape = X.size()
    nb_batch = shape[:dim].numel()
    length = shape[dim]
    inner = shape[dim + 1 : -1]

    A = A.reshape(nb_batch, length, *inner)
    X = X.reshape(nb_batch, length, *inner, -1)

    if Y_init is not None:
        Y_init = Y_init.reshape(nb_batch, *inner, -1)
    else:
        Y_init = X.new_zeros(nb_batch, *inner, X.size(-1))

    return pscan.pscan(A, X, Y_init).reshape(shape)
148
149
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` over the second-to-last axis of X.

    X is .../xTxD, A is .../xT and Y_init is .../xD, or None for a zero
    initial state. All leading axes are flattened into one batch axis. The
    result has the shape of X.

    NOTE(review): ``pscan.pscan`` is defined in a separate module. It is
    presumably the parallel scan Y_t = A_t * Y_{t-1} + X_t.
    """
    shape = X.size()
    length, width = shape[-2], shape[-1]

    A = A.reshape(-1, length)
    X = X.reshape(-1, length, width)

    if Y_init is not None:
        Y_init = Y_init.reshape(-1, width)
    else:
        Y_init = X.new_zeros(X.size(0), width)

    return pscan.pscan(A, X, Y_init).reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Running sum over the second-to-last axis, projected back into the unit ball.

    X is .../xTxD and Y_init, when not None, is .../xD. The result has the
    shape of X, with
        Y_t = r(Y_{t-1} + X_t),   r(v) = v / max(1, ||v||),
    where Y_{-1} is Y_init, or 0 when Y_init is None.

    Inputs with no elements (e.g. an empty time axis) return an empty tensor
    of the same shape.
    """
    s = X.size()

    if X.numel() == 0:
        # reshape(-1, ...) is ambiguous with zero elements, and there is
        # nothing to accumulate anyway.
        return X.new_zeros(s)

    X = X.reshape(-1, s[-2], s[-1])  # N x T x D

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for x_t in X.unbind(dim=1):
        Y = Y + x_t
        # Only rescale vectors whose norm exceeds 1
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.stack(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent token-mixing layer with ``nb_lines`` memory lines of values.

    Write: every head computes a softmax over the lines of its query against
    learned line keys ``k_star``, with the keys rotated over time (see
    forward). It adds its gated value projection to those lines. Each line
    decays by the total weight written into it: A = 1 - sum over heads.

    Read: a second query attends over the lines (softmax over lines) and the
    per-head results are concatenated and projected by ``w_o``.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility with the other layers and are not used here.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated per-line write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean write attention per line.

        The mean attention sums to 1 over the lines, so this is smallest when
        all lines are used equally. Returns a zero scalar when nothing has
        been accumulated since the last reset, e.g. after eval-mode passes
        only.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state: N x nb_lines x T x dim_v
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating barrel: the key used to write line l at time t is
        # k_star[(l + t) % nb_lines].
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # Accumulate the line usage statistics for get_inner_loss, only when
        # training (self.train is a method and is always truthy).
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Resume from the state just before the bracket, if any
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        # NOTE(review): the readout uses the un-rotated k_star while the write
        # uses the rotated barrel — confirm this asymmetry is intended.
        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
309
310
311 ##############################
312
313
class KVRec(nn.Module):
    """Recurrent token-mixing layer with ``nb_lines`` memory lines of key/value pairs.

    Write: every head computes a softmax over the lines of its write query
    against learned line keys ``k_star``, rotated over time (see forward). It
    adds its gated key and value projections to those lines. Each line decays
    by the total weight written into it: A = 1 - sum over heads.

    Read: a second query attends over the recurrent keys of the lines (softmax
    over lines) to combine the recurrent values. The per-head results are
    concatenated and projected by ``w_o``.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility with the other layers and are not used here.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated per-line write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean write attention per line.

        The mean attention sums to 1 over the lines, so this is smallest when
        all lines are used equally. Returns a zero scalar when nothing has
        been accumulated since the last reset, e.g. after eval-mode passes
        only.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state: N x nb_lines x T x dim_v (values), dim_qk (keys)
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating barrel: the key used to write line l at time t is
        # k_star[(l + t) % nb_lines].
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # We want all the memory lines to be used similarly: accumulate the
        # usage (summed across N x H x T) only when training.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Resume from the state just before the bracket, if any
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
445
446
447 ##############################
448
449
450 # Returns a tensor with an additional index at rank win_dim, that move
451 # along the same dimension as dim, on a domain {0...win_size-1}, and
452 # dim is restricted on a domain reduced by win_size-1 values.
453
454
455 def moving_window(x, dim, win_dim, win_size):
456     size, stride = x.size(), x.stride()
457     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
458     size = size[:win_dim] + (win_size,) + size[win_dim:]
459     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
460
461     return x.as_strided(size=size, stride=stride)
462
463
464 ##############################
465
466
class Caterpillar(nn.Module):
    """Recurrent token-mixing layer whose state is a "caterpillar" of key/value pairs.

    The state has caterpillar_height (R) rows and caterpillar_length (L)
    columns. The state at time t is the state at time t - L, decayed and
    updated with gated keys and values of the current token. The readout
    lets a query attend over the R x L keys of the last L time steps.

    ``len_max`` is accepted for interface compatibility and is not used here.
    The optional keyword ``default_bg`` sets the initial gate bias.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Probability of the "flashback" gate dropout in forward. Hard-coded
        # to 0.0 here, so that branch is disabled unless this attribute is
        # changed from outside.
        self.proba_gate_dropout = 0.0

        # Initial gate bias. The default -log(R - 1) gives
        # sigmoid(default_bg) = 1 / R. NOTE: with caterpillar_height == 1 this
        # raises ValueError (log of 0) unless default_bg is passed explicitly.
        default_bg = kwargs.get("default_bg")
        if default_bg is None:
            default_bg = -math.log(caterpillar_height - 1)
        else:
            default_bg = float(default_bg)

        logger(f"default_bg {default_bg}")

        # Gating weights and bias: one gate per (head, caterpillar row)
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned initial content of the caterpillar (R x L x D), copied into
        # the L time steps preceding bs.first when the cache is initialized
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # The regularization is currently disabled: always returns a zero
        # tensor on the parameters' device.
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        """Update the caterpillar over the bracket of ``bs`` and compute the readout.

        Requires bs.first >= caterpillar_length and bs.nb to be a multiple of
        caterpillar_length (asserted below). When bs.init_cache is True, the
        L time steps before bs.first receive the learned initial state.
        """
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()  # N x H x R x (t1 - t0)

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row: the per-row sum over heads is rescaled
        # to at most 1

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################
        # Roll the gating indexes

        # warnings.warn("rotating barrel", RuntimeWarning)

        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
        # G = G.gather(dim=2, index=r_barrel.expand_as(G))

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT where R indexes the caterpillar's rows.

            warnings.warn("gate dropout", RuntimeWarning)
            epsilon = 0.5

            # One-hot over time at a uniformly random position per (n, h):
            # the sorted position where original index 0 lands is uniform
            dropout_head = (
                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
                .expand_as(G)
                .float()
            )

            # 1 strictly after the selected position, 0 elsewhere
            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

            # Per sample, apply the flashback with probability proba_gate_dropout
            dropout_active = (
                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
            ).long()

            dropout_head *= dropout_active
            dropout_tail *= dropout_active

            # At the selected step the gate value becomes 1 - epsilon, and
            # after it the value becomes 0. The gradient of G is kept through
            # the detach() trick.
            G = (
                G
                + dropout_head * (1 - epsilon - G.detach())
                - dropout_tail * G.detach()
            )

        ######################################################################

        # We prepare the arguments for the parallel scan: each row decays by
        # the total gate applied to it across heads

        A = 1 - G.sum(1)  # N x R x (t1 - t0)

        # warnings.warn("harmonic recurrence", RuntimeWarning)
        # har = torch.arange(t0, t1, device = G.device).float() + 1
        # A = har / (har + 1)
        # G = G / har

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # The windowed state tensors are NxFxTxLxD, where N is the sample
        # index, F the row in the caterpillar, T the time, L the column in
        # the caterpillar (the last L time steps) and D the key/value
        # dimension

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values:
        # ar is NxHxTxFxL

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
708
709
710 ##############################
711
712
class QKVAttention(nn.Module):
    """Multi-head dot-product attention over BracketedSequences, optionally causal.

    Keys and values of all processed time steps are cached, so each bracket
    only projects its own tokens. Non-causal attention requires the bracket
    to cover the whole sequence. When ``record_attention`` is True, the last
    attention map (before dropout) is stored in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def make_weight(*shape):
            # Gaussian init scaled by 1/sqrt(fan-in = last dimension)
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = make_weight(nb_heads, dim_qk, dim_model)
        self.w_k = make_weight(nb_heads, dim_qk, dim_model)
        self.w_v = make_weight(nb_heads, dim_v, dim_model)
        self.w_o = make_weight(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        nb_samples, length = x.size(0), x.size(1)

        if bs.init_cache:
            self.cache_k = x.new_zeros(
                nb_samples, self.w_k.size(0), length, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                nb_samples, self.w_v.size(0), length, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(nb_samples, length, self.w_o.size(1))

        x_new = x[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                # cache_attzero[..., t, s] is True when s > t (future position)
                steps = torch.arange(length, device=q.device)
                self.cache_attzero = (
                    steps[None, None, :, None] < steps[None, None, None, :]
                )
            future = self.cache_attzero[:, :, t0:t1, :t1]
            a = a.masked_fill(future, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        heads = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1])
        self.cache_y[:, t0:t1] = heads.flatten(2) @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
795
796
797 ##############################
798
799
800 class MyGPT(nn.Module):
801     def __init__(
802         self,
803         vocabulary_size,
804         dim_model,
805         dim_keys,
806         dim_hidden,
807         nb_heads,
808         nb_blocks,
809         nb_lines=None,
810         caterpillar_height=None,
811         causal=False,
812         dropout=0.0,
813         len_max=1e5,
814         attention_layer="kvrec",
815         logger=print,
816         **kwargs,
817     ):
818         super().__init__()
819
820         assert attention_layer in {
821             "mha",
822             "dumbrec",
823             "kvrec",
824             "caterpillar",
825         }, f"Unknown attention operator {attention_layer}."
826
827         if attention_layer == "caterpillar":
828             assert nb_lines % caterpillar_height == 0
829             self.caterpillar_length = nb_lines // caterpillar_height
830             self.caterpillar_height = caterpillar_height
831         else:
832             self.caterpillar_length = -1
833             self.caterpillar_height = -1
834
835         assert dim_model % nb_heads == 0
836
837         self.embedding = nn.Sequential(
838             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
839             AddPositionalEncoding(len_max),
840         )
841
842         trunk_blocks = []
843
844         def attlayer():
845             if attention_layer == "mha":
846                 return QKVAttention(
847                     dim_model=dim_model,
848                     dim_qk=dim_keys,
849                     dim_v=dim_model // nb_heads,
850                     nb_heads=nb_heads,
851                     causal=causal,
852                     attention_dropout=dropout,
853                     logger=logger,
854                     **kwargs,
855                 )
856             elif attention_layer == "dumbrec":
857                 return DumbRec(
858                     dim_model=dim_model,
859                     dim_qk=dim_keys,
860                     dim_v=dim_model // nb_heads,
861                     nb_heads=nb_heads,
862                     nb_lines=nb_lines,
863                     attention_dropout=dropout,
864                     logger=logger,
865                     **kwargs,
866                 )
867             elif attention_layer == "kvrec":
868                 return KVRec(
869                     dim_model=dim_model,
870                     dim_qk=dim_keys,
871                     dim_v=dim_model // nb_heads,
872                     nb_heads=nb_heads,
873                     nb_lines=nb_lines,
874                     attention_dropout=dropout,
875                     logger=logger,
876                     **kwargs,
877                 )
878             elif attention_layer == "caterpillar":
879                 return Caterpillar(
880                     dim_model=dim_model,
881                     dim_qk=dim_keys,
882                     dim_v=dim_model // nb_heads,
883                     nb_heads=nb_heads,
884                     caterpillar_length=self.caterpillar_length,
885                     caterpillar_height=self.caterpillar_height,
886                     attention_dropout=dropout,
887                     logger=logger,
888                     **kwargs,
889                 )
890             else:
891                 raise ValueError(f"Unknown attention type {attention_layer}.")
892
893         for b in range(nb_blocks):
894             trunk_blocks += [
895                 WithResidual(
896                     CacheWrapper(nn.LayerNorm((dim_model,))),
897                     attlayer(),
898                 ),
899                 WithResidual(
900                     CacheWrapper(
901                         nn.LayerNorm((dim_model,)),
902                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
903                         nn.ReLU(),
904                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
905                         nn.Dropout(dropout),
906                     ),
907                 ),
908             ]
909
910         self.trunk = nn.Sequential(*trunk_blocks)
911
912         self.readout = CacheWrapper(
913             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
914         )
915
916         with torch.no_grad():
917             for m in self.modules():
918                 if isinstance(m, nn.Embedding):
919                     m.weight.normal_(mean=0, std=2e-2)
920                 elif isinstance(m, nn.LayerNorm):
921                     m.bias.zero_()
922                     m.weight.fill_(1.0)
923
924         self.reset_inner_loss()
925
926     def forward(self, bs):
927         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
928
929         # To make the code simpler in the Caterpillar layer, we pad
930         # here. It's unclear if/how much it hurts computationaly by
931         # increasing the sequence length for the other layers
932
933         if self.caterpillar_length > 0:
934             original_nb = bs.nb
935             if bs.nb % self.caterpillar_length > 0:
936                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
937
938             bs = BracketedSequence(
939                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
940                 bs.first + self.caterpillar_length,
941                 bs.nb,
942                 bs.init_cache,
943             )
944
945         bs = self.embedding(bs)
946         bs = self.trunk(bs)
947         bs = self.readout(bs)
948
949         if self.caterpillar_length > 0:
950             bs = BracketedSequence(
951                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
952                 bs.first - self.caterpillar_length,
953                 original_nb,
954                 bs.init_cache,
955             )
956
957         return bs
958
959     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
960     # 1s where tokens should be generated. The others are kept
961     # unchanged.
962
963     def masked_inplace_autoregression(
964         self,
965         input_src,
966         ar_mask_src,
967         forbidden_tokens=None,
968         deterministic_synthesis=False,
969     ):
970         input = input_src.to(self.readout.f.weight.device)
971         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
972         to_generate = (ar_mask.sum(0) > 0).nonzero()
973         if to_generate.min() > 0:
974             self(
975                 BracketedSequence(input, 0, to_generate.min(), True)
976             )  # Needed to initialize the model's cache
977         for s in range(to_generate.min(), to_generate.max() + 1):
978             output = self(BracketedSequence(input, s, 1, s == 0)).x
979             logits = output[:, s]
980             if forbidden_tokens is not None:
981                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
982             if deterministic_synthesis:
983                 t_next = logits.argmax(1)
984             else:
985                 dist = torch.distributions.categorical.Categorical(logits=logits)
986                 t_next = dist.sample()
987             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
988
989         input_src.copy_(input)
990
991     def reset_inner_loss(self):
992         for m in self.modules():
993             if m is not self and hasattr(m, "reset_inner_loss"):
994                 m.reset_inner_loss()
995
996     def get_inner_loss(self):
997         l = torch.tensor([0.0], device=self.readout.f.weight.device)
998         for m in self.modules():
999             if m is not self and hasattr(m, "get_inner_loss"):
1000                 l += m.get_inner_loss()
1001         return l
1002
1003     def record_attention(self, v=True):
1004         for m in self.modules():
1005             if isinstance(m, QKVAttention):
1006                 m.record_attention = v
1007
1008     def retrieve_attention(self):
1009         a = []
1010         for m in self.modules():
1011             if isinstance(m, QKVAttention):
1012                 a.append(m.a)
1013         return a
1014
1015
1016 ######################################################################
1017
if __name__ == "__main__":
    import sys
    import time

    # Consistency check of the Caterpillar layer: 21 time steps, with a
    # margin of 7 steps on each side, are processed twice as one bracket,
    # then as two successive brackets where the second one reuses the cache
    # (init_cache=False). Both printed maximum absolute differences are
    # presumably expected to be ~0.
    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The speed benchmark below builds a 12-block model and runs it on
    # 6x1024 tokens, so it only runs when explicitly requested.
    if "--benchmark" not in sys.argv[1:]:
        sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    def synchronize():
        # CUDA kernels run asynchronously: wait for them so that
        # perf_counter() measures the actual compute, not just the launches.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # To profile instead of timing, wrap the loop below in
    # torch.profiler.profile(...) and print prof.key_averages().table(...).
    model.eval()
    with torch.no_grad():  # inference only: no need to build autograd graphs
        for _ in range(3):
            synchronize()
            start_time = time.perf_counter()
            for _ in range(10):
                model(BracketedSequence(x))
            synchronize()
            duration = time.perf_counter() - start_time
            print(duration)
            sys.stdout.flush()
1088
1089 ######################################################################