Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.

# This implementation also provides recurrent layers that can replace
# the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a bracket: the
# first time step and the number nb of time steps to compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor ``x`` paired with the time bracket [first, first + nb).

    ``first``, ``nb`` and ``init_cache`` must be given all together or not
    at all. When omitted, the bracket spans the whole time axis and
    ``init_cache`` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x

        given = [arg is not None for arg in (first, nb, init_cache)]
        # A partially specified bracket is ambiguous, hence rejected
        assert all(given) or not any(given)

        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        if init_cache is None:
            init_cache = True

        self.first, self.nb, self.init_cache = first, nb, init_cache

    def slice(self):
        """Return the part of ``x`` that lies inside the bracket (time axis 1)."""
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        """Return True when the bracket covers the full time axis."""
        return self.nb == self.x.size(1) and self.first == 0
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the bracket of a BracketedSequence and cache its output.

    The wrapped module (or an ``nn.Sequential`` of the given modules) is
    applied to ``bs.slice()`` only. Its output is written into a cache that
    spans the whole time axis, and the cache is returned with the same
    bracket.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled rather than uninitialized, so that time steps not
            # computed yet hold deterministic values instead of arbitrary
            # memory, which could contain NaN/inf.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            y = self.f(bs.slice())

        self.cache_y[:, bs.first : bs.first + bs.nb] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module.

    The wrapped module (or ``nn.Sequential`` of the given modules) receives
    the BracketedSequence. Its output tensor is added to the input tensor,
    and the bracket is kept unchanged.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        x = bs.x
        y = self.f(bs).x
        return BracketedSequence(x + y, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to the bracket of a BracketedSequence.

    [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    with L = ``len_max`` and D = bs.x.size(2).
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # sin(u + pi/2) = cos(u): odd channels get the cosine
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled so that time steps not computed yet are deterministic
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
133 def pscan_dim(A, X, Y_init, dim=-2):
134     s = X.size()
135     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
136
137     A = A.reshape(a, T, *s[dim + 1 : -1])
138     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139
140     if Y_init is None:
141         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
142     else:
143         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
144
145     Y = pscan.pscan(A, X, Y_init).reshape(s)
146
147     return Y
148
149
def pscan_shape(A, X, Y_init):
    """Apply ``pscan.pscan`` along the second-to-last dimension of X.

    X is ... x T x D, A is ... x T, and Y_init is ... x D, or None for a
    zero initial state. All leading dimensions are flattened into one batch
    dimension for the scan and restored on output.
    """
    shape = X.size()
    nb_steps, dim = shape[-2], shape[-1]

    flat_A = A.reshape(-1, nb_steps)
    flat_X = X.reshape(-1, nb_steps, dim)

    if Y_init is not None:
        flat_init = Y_init.reshape(-1, dim)
    else:
        flat_init = flat_X.new_zeros(flat_X.size(0), dim)

    return pscan.pscan(flat_A, flat_X, flat_init).reshape(shape)
163
164
165 def nsum_shape(X, Y_init):
166     s = X.size()
167     X = X.reshape(-1, s[-2], s[-1])  # ntd
168
169     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170     result = []
171
172     for k in range(X.size(1)):
173         Y = Y + X[:, k]
174         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175         result.append(Y)
176
177     return torch.cat(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention with value-only memory lines.

    For every sample, the recurrent state holds ``nb_lines`` value vectors
    per time step. At each step the heads attend over the learned keys
    ``k_star``, rotated with time, to decide how much of their value to
    write into each line. The state is updated with a parallel scan. The
    output is read out by attending over ``k_star`` again.

    ``len_max`` is accepted for interface compatibility but is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # Initialize the accumulators so that forward() and get_inner_loss()
        # work without an explicit reset_inner_loss() call first.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Reset the memory-line usage statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared average write attention.

        Returns a zero scalar when no training-mode forward happened since
        the last reset, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_k = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_k, device=self.k_star.device)[:, None] + t_barrel[None, :]
        ) % nb_k
        k_star = self.k_star[l_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training, not self.train: the latter is a method, hence always
        # truthy. Line-usage statistics must only accumulate in training mode.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue from the cached state of the previous bracket, if any
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
308
309 ##############################
310
311
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention with key/value memory lines.

    For every sample, the recurrent state holds ``nb_lines`` key and value
    vectors per time step. At each step the heads attend over the learned
    keys ``k_star``, rotated with time, to decide how much of their key and
    value to write into each line. The state is updated with a parallel
    scan. The output is read out by attending over the recurrent keys.

    ``len_max`` is accepted for interface compatibility but is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # Initialize the accumulators so that forward() and get_inner_loss()
        # work without an explicit reset_inner_loss() call first.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Reset the memory-line usage statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared average write attention.

        Returns a zero scalar when no training-mode forward happened since
        the last reset, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_k = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_k, device=self.k_star.device)[:, None] + t_barrel[None, :]
        ) % nb_k
        k_star = self.k_star[l_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training, not self.train: the latter is a method, hence always
        # truthy. We want all the memory lines to be used similarly, so their
        # usage is accumulated (summed across NxHx_xT) in training mode only.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the cached state of the previous bracket, if any
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with sliding windows of size ``win_size`` along ``dim``.

    The result has an additional dimension of size ``win_size``, inserted
    at position ``win_dim`` of x's shape (as with ``list.insert``). It moves
    along the same axis as ``dim``, whose size is reduced by
    ``win_size - 1``. In index terms, result[..., t, ..., w, ...] equals
    x[..., t + w, ...]. ``dim`` may be negative. No data is copied: the
    result is a view sharing x's storage.
    """
    # Normalize dim: for dim=-1, size[dim + 1 :] below would otherwise be
    # the whole shape instead of an empty tail.
    dim = dim % x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
458
459
460 ##############################
461
462
class Caterpillar(nn.Module):
    """Recurrent replacement for multi-head attention with a "caterpillar" memory.

    For every sample and time step, the recurrent state holds CH key/value
    pairs (CH = ``caterpillar_height``). The pairs at time t are a gated
    update of those at time t - CL (CL = ``caterpillar_length``), so the
    parallel scan runs with a period of CL. The readout at time t attends
    over the CH x CL pairs of times t - CL + 1 ... t.

    ``len_max`` is accepted for interface compatibility but is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # NOTE(review): gate dropout is not implemented in forward(). This
        # value currently only controls whether a warning is emitted.
        self.proba_gate_dropout = 0.0

        # The gate bias is initialized so that sigmoid(b_G) = 1 / caterpillar_height.
        # math.log fails for caterpillar_height <= 1, so it must be >= 2.
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No auxiliary loss for this layer. Returned as a float tensor, like
        # the other layers' losses, rather than an integer one.
        return torch.tensor([0.0], device=self.w_Q.device)

    def forward(self, bs):
        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        # Dimensions to make the source a bit clearer
        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be greater than or equal to caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. There are CH independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        ######################################################################

        if self.training and self.proba_gate_dropout > 0.0:
            # TODO: gate dropout is not implemented, only this warning is emitted
            warnings.warn("gate dropout", RuntimeWarning)

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-CL, the parallel
        # scan operates with a period of CL. To do so we split the
        # sequence indexing in two axes, the second of size CL, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build windowed views of shape N x CH x T x CL x D ("nftld"),
        # where N is the sample index, F the row in the caterpillar, T the
        # time, L the column in the caterpillar and D the feature.

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
645
646 ##############################
647
648
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention working on BracketedSequences.

    Keys and values of every processed time step are cached, so a later
    bracket only computes its own queries, keys and values. With
    ``causal=True``, a query at time t ignores keys at times after t. When
    ``record_attention`` is True, the attention weights (before dropout) of
    the last call are stored in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def scaled_randn(*shape):
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Created in this order so that a seeded RNG yields the same weights
        self.w_q = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_k = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_v = scaled_randn(nb_heads, dim_v, dim_model)
        self.w_o = scaled_randn(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            nb_samples, nb_steps = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(
                nb_samples, self.w_k.size(0), nb_steps, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                nb_samples, self.w_v.size(0), nb_steps, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(nb_samples, nb_steps, self.w_o.size(1))

        x_bracket = x[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_v)

        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                pos = torch.arange(x.size(1), device=q.device)
                # True where the key time s comes after the query time t
                self.cache_attzero = pos[None, None, :, None] < pos[None, None, None, :]
            a = a.masked_fill(self.cache_attzero[:, :, t0:t1, :t1], float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
729
730
731 ##############################
732
733
class MyGPT(nn.Module):
    """GPT-like model whose attention layers may be replaced by recurrent ones.

    The model is built from:
    - a token embedding with dropout and positional encoding;
    - ``nb_blocks`` residual blocks, each made of a pre-LayerNorm
      token-mixing layer and a pre-LayerNorm MLP;
    - a linear readout.

    ``attention_layer`` selects the token-mixing layer: "mha"
    (QKVAttention), "dumbrec", "kvrec" or "caterpillar". For "caterpillar",
    ``nb_lines`` must be a multiple of ``caterpillar_height``, and the
    caterpillar length is ``nb_lines // caterpillar_height``. Every
    sub-module processes BracketedSequences and caches its output.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        """Compute the logits for the bracket of ``bs``.

        The tokens are shifted right by one position: a 0 is prepended and
        the last token is dropped. Hence the output at time t only depends
        on tokens before t.
        """
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationaly by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        """Sample, in place in ``input_src``, the tokens where ``ar_mask_src`` is 1.

        Time steps are processed one at a time, from the first to the last
        one where any sample has a token to generate. Sampling uses argmax
        when ``deterministic_synthesis`` is True and a categorical
        distribution otherwise. ``forbidden_tokens``, if given, is a mask
        of logits set to -inf. Does nothing when the mask is all zeros.
        """
        device = self.readout.f.weight.device
        tokens = input_src.to(device)
        ar_mask = ar_mask_src.to(device)

        # Time steps where at least one sample has a token to generate
        to_generate = (ar_mask.sum(0) > 0).nonzero()

        if to_generate.numel() == 0:
            # Nothing to generate. min()/max() would fail on an empty tensor.
            return

        first, last = to_generate.min().item(), to_generate.max().item()

        if first > 0:
            # Needed to initialize the model's cache with the prefix
            self(BracketedSequence(tokens, 0, first, True))

        for s in range(first, last + 1):
            output = self(BracketedSequence(tokens, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            tokens[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * tokens[:, s]

        input_src.copy_(tokens)

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a
938
939
940 ######################################################################
941
if __name__ == "__main__":
    import sys
    import time

    # Consistency check of the Caterpillar layer. x has 21 + 2 * 7 time
    # steps and the brackets below cover steps 7 to 27. The same span is
    # processed:
    #   - twice in a single bracket (y1, y2), and
    #   - as two successive brackets (y3a, y3b), the second one with
    #     init_cache=False.
    # Both printed maximum differences are expected to be (close to) zero.

    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The MyGPT timing benchmark below used to sit after an unconditional
    # exit(0), so it could never run. It is now opt-in: pass --benchmark to
    # run it. Without the flag the script stops here, as before.
    # sys.exit is used rather than the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S`).
    if "--benchmark" not in sys.argv[1:]:
        sys.exit(0)

    # NOTE(review): the MyGPT arguments below were never exercised while this
    # code was unreachable; confirm they match MyGPT's current signature.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    # Time 3 rounds of 10 full-sequence forward passes each.
    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1012
1013 ######################################################################