#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is, a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid a O(N^3) cost
# for auto-regression.

import math, warnings

import torch, einops

from torch import nn
from torch.nn import functional as F
from functorch.dim import dims

import ffutils

# import memload

######################################################################

# A BracketedSequence is a BxTx... tensor together with a first time
# step and a number of time steps to compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although the mechanism is more general, a classical prompt-conditioned
# auto-regressive process consists of a first bracket starting at 0, of
# arbitrary length, for the "prompt", followed by brackets of length 1
# for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True, typically for the input bracket
# starting at t=0


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)

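# An illustrative sketch (not used by the code below) of how brackets
# would be fed to some module m processing BracketedSequences, for a
# hypothetical prompt of length 5 followed by token-by-token
# generation:
#
#   m(BracketedSequence(x, first=0, nb=5, init_cache=True))   # the prompt, resets the cache
#   m(BracketedSequence(x, first=5, nb=1, init_cache=False))  # first generated token
#   m(BracketedSequence(x, first=6, nb=1, init_cache=False))  # and so on
#
# BracketedSequence(x) alone is equivalent to first=0, nb=x.size(1),
# init_cache=True, i.e. the full sequence processed in a single bracket.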

######################################################################


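# CacheWrapper applies a module f (or the nn.Sequential of several
# modules) to the bracket under computation and stores the result in a
# cache tensor covering the full sequence, so that subsequent brackets
# only compute their own time steps.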
class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
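    #
    # The code below computes both cases at once: with k = j % 2 the
    # exponent uses the even index j - k, and the pi / 2 * k phase
    # shift turns the sine into a cosine for odd j.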

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


import pscan


# X is /.../xTxD   A is /.../xT   Y_init is /.../xD
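#
# pscan.pscan(A, X, Y_init) is assumed here to compute, along the time
# axis, the linear recurrence Y[t] = A[t] * Y[t-1] + X[t] starting from
# Y_init; pscan_dim and pscan_shape below only reshape their arguments
# so that this recurrence is applied along the requested dimension.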


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        # self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]
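
        # After this indexing, at time step t the memory line l is
        # written with key self.k_star[(l + t) % nb_lines], i.e. the
        # assignment of write keys to lines rotates as time advances.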

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt
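        # A[n, l, t] = 1 - sum_h aw[n, h, l, t] acts as the "keep"
        # coefficient of line l, so that the scan below computes
        # rec_v[l, t] = A[l, t] * rec_v[l, t - 1] + V[l, t]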

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        # n,h,l,t,d = dims(5)

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across N x H x _ x T
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

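        # Unlike in DumbRec, the readout queries attend to the
        # recurrent keys rec_k rather than to the static k_star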
        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)

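# For instance (purely illustrative), for x of size (N, H, T, D),
# moving_window(x, dim=2, win_dim=3, win_size=L) returns a view y of
# size (N, H, T - L + 1, L, D) with y[n, h, t, l] = x[n, h, t + l],
# i.e. a sliding window of L consecutive time steps, without copying
# any data.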

##############################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.w_G = randw(nb_heads, caterpillar_height, dim_in)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_in)
        self.w_V = randw(nb_heads, dim_v, dim_in)
        self.w_Q = randw(nb_heads, dim_qk, dim_in)
        self.w_O = randw(dim_v * nb_heads, dim_in)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Shorthand names for the dimensions, to make the code below clearer

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        Dout = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K = X.new_zeros(N, CH, T, DK)
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
            self.cache_Y = X.new_zeros(N, T, Dout)

        ######################################################################
        # Compute the recurrent state

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

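        # The time axis is cut into chunks of length CL and the scan
        # runs over the chunk index, so each recurrent line e is
        # updated as
        #
        #   rec[e, t] = (1 - sum_h G[h, e, t]) * rec[e, t - CL]
        #               + sum_h G[h, e, t] * V[h, t]
        #
        # (and similarly for K): the state at time t is computed from
        # the state CL steps earlier.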
        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # Compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        uv = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        uk = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )
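
        # uk[n, e, t, l] and uv[n, e, t, l] hold the keys and values of
        # recurrent line e at times t0 + t - CL + 1, ..., t0 + t, so the
        # query at time t0 + t attends to the CL most recent states of
        # each of the CH lines.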

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            uk,
        ) / math.sqrt(DK)

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            uv,
        ).flatten(2)

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)


##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

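        # In the causal case, a boolean mask attzero[t, s] = (t < s) is
        # precomputed once for the full sequence, and the logits of
        # future key positions are set to -inf before the softmax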
        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        dim_rec_v=-1,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
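        # F.pad(bs.x, (1, -1)) drops the last token and prepends a 0,
        # so that the prediction at position t is conditioned only on
        # the tokens at positions strictly before t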

        # To make the code simpler in the Caterpillar layer, we pad the
        # sequence here. It is unclear whether, and by how much, this
        # hurts computationally by increasing the sequence length for
        # the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of the same shape as the
    # input, with 1s where tokens should be generated. The other tokens
    # are kept unchanged.
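    #
    # For instance (purely illustrative), with input [[12, 34, 0, 0]]
    # and ar_mask [[0, 0, 1, 1]], the first two tokens are kept and the
    # last two are generated auto-regressively, conditioned on what
    # precedes them.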

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


######################################################################

if __name__ == "__main__":
    print("Basic check.")

    m = Caterpillar(
        dim_in=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################