[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is also equipped with RNN-like layers that can replace the MHA layers.
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor together with a first time
29 # step and a number nb of time steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True.
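#
# For illustration only (a rough sketch, not executed here): a prompt of
# length 5 followed by token-by-token generation could be processed as
#
#   bs = BracketedSequence(x, first=0, nb=5, init_cache=True)
#   y = model(bs)
#   for t in range(5, x.size(1)):
#       y = model(BracketedSequence(x, first=t, nb=1, init_cache=False))
#
# where model is any module following this convention.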
43
44
45 class BracketedSequence:
46     def __init__(self, x, first=None, nb=None, init_cache=None):
47         self.x = x
48         assert (first is None and nb is None and init_cache is None) or (
49             first is not None and nb is not None and init_cache is not None
50         )
51
52         self.first = 0 if first is None else first
53         self.nb = x.size(1) if nb is None else nb
54         self.init_cache = True if init_cache is None else init_cache
55
56     def slice(self):
57         return self.x[:, self.first : self.first + self.nb]
58
59     def complete(self):
60         return self.first == 0 and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
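# Applies a module (or a sequential composition of modules) to the current
# bracket only, and writes the result into a cache covering the whole
# sequence, so that successive brackets accumulate into one consistent
# BxTx... output tensor.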
66 class CacheWrapper(nn.Module):
67     def __init__(self, *f):
68         super().__init__()
69         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
70
71     def forward(self, bs):
72         if bs.init_cache:
73             y = self.f(bs.slice())
74             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
75             self.cache_y[:, bs.first : bs.first + bs.nb] = y
76         else:
77             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
78             assert bs.first + bs.nb <= self.cache_y.size(1)
79             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
80
81         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
87 class WithResidual(nn.Module):
88     def __init__(self, *f):
89         super().__init__()
90         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
91
92     def forward(self, bs):
93         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
99 class AddPositionalEncoding(nn.Module):
100     def __init__(self, len_max):
101         super().__init__()
102         self.len_max = len_max
103
104     # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
105
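    # With k = j % 2, the single expression below computes
    # sin(t / L^{(j-k)/D} + (pi/2) k), which is the sine formula on even
    # channels (k=0) and, since sin(u + pi/2) = cos(u), the cosine formula
    # on odd channels (k=1).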
106     def forward(self, bs):
107         if bs.init_cache:
108             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109                 :, None
110             ]
111             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
112                 None, :
113             ]
114             k = j % 2
115             self.pe = torch.sin(
116                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
117             )
118             self.cache_y = bs.x.new(bs.x.size())
119
120         self.cache_y[:, bs.first : bs.first + bs.nb] = (
121             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122         )
123
124         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
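#
# pscan_dim and pscan_shape only reshape their arguments so that the scan
# runs over the desired dimension. Assuming pscan.pscan implements the
# associative recurrence Y[t] = A[t] * Y[t-1] + X[t] (a standard parallel
# prefix scan), a sequential reference for A of shape (N, T) and X of
# shape (N, T, D) would be
#
#   Y = Y_init
#   for t in range(T):
#       Y = A[:, t, None] * Y + X[:, t]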
131
132
133 def pscan_dim(A, X, Y_init, dim=-2):
134     s = X.size()
135     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
136
137     A = A.reshape(a, T, *s[dim + 1 : -1])
138     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
139
140     if Y_init is None:
141         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
142     else:
143         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
144
145     Y = pscan.pscan(A, X, Y_init).reshape(s)
146
147     return Y
148
149
150 def pscan_shape(A, X, Y_init):
151     s = X.size()
152     A = A.reshape(-1, s[-2])
153     X = X.reshape(-1, s[-2], s[-1])
154
155     if Y_init is None:
156         Y_init = X.new_zeros(X.size(0), s[-1])
157     else:
158         Y_init = Y_init.reshape(-1, s[-1])
159
160     Y = pscan.pscan(A, X, Y_init).reshape(s)
161
162     return Y
163
164
165 def nsum_shape(X, Y_init):
166     s = X.size()
167     X = X.reshape(-1, s[-2], s[-1])  # ntd
168
169     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
170     result = []
171
172     for k in range(X.size(1)):
173         Y = Y + X[:, k]
174         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
175         result.append(Y)
176
177     return torch.cat(result, dim=1).reshape(s)
178
179
180 ##############################
181
182
183 class DumbRec(nn.Module):
184     def __init__(
185         self,
186         dim_model,
187         dim_qk,
188         dim_v,
189         nb_heads,
190         nb_lines,
191         attention_dropout=0.0,
192         len_max=1e5,
193     ):
194         super().__init__()
195
196         def randw(*d):
197             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
198
199         self.nb_lines = nb_lines
200         self.attention_dropout = attention_dropout
201
202         self.k_star = randw(nb_lines, dim_qk)
203
204         self.w_qw = randw(nb_heads, dim_qk, dim_model)
205         self.w_qr = randw(nb_heads, dim_qk, dim_model)
206         # self.w_k = randw(nb_heads, dim_qk, dim_model)
207         self.w_v = randw(nb_heads, dim_v, dim_model)
208         self.w_o = randw(dim_v * nb_heads, dim_model)
209
210     def reset_inner_loss(self):
211         self.acc_attention = 0
212         self.acc_nb = 0
213
214     def get_inner_loss(self):
215         warnings.warn("l2 regularization", RuntimeWarning)
216         return (self.acc_attention / self.acc_nb).pow(2).sum()
217         # return torch.tensor([0], device=self.w_qw.device)
218
219     def forward(self, bs):
220         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
221
222         if bs.init_cache:
223             self.rec_v = x_q.new_zeros(
224                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
225             )
226             # self.rec_k = x_q.new_zeros(
227             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
228             # )
229             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
230
231         ######################################################################
232         # Prepare the keys
233
234         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
235
236         warnings.warn("rotating key barrel", RuntimeWarning)
237         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
238         t_barrel = torch.arange(t0, t1, device=k_star.device)
239         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
240         l_barrel = (
241             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
242         ) % k_star.size(0)
243         k_star = k_star[l_barrel, t_barrel]
244
245         ######################################################################
246         # Compute the recurrent state
247
248         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
249
250         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
251         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
252
253         aw = torch.einsum(
254             "nhtd,ltd->nhlt",
255             qw,
256             k_star,
257         ) / math.sqrt(self.w_qw.size(1))
258
259         aw = aw.softmax(dim=2)  # nhlt
260
261         if self.training:
262             self.acc_attention += aw.sum(dim=(0, 1, 3))
263             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
264
265         aw = F.dropout(aw, self.attention_dropout, self.training)
266
267         A = 1 - aw.sum(dim=1)  # nlt
268
269         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
270         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
271
272         if t0 == 0:
273             V0 = None
274             # K0 = None
275         else:
276             V0 = self.rec_v[:, :, t0 - 1]
277             # K0 = self.rec_k[:, :, t0 - 1]
278
279         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
280         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
281
282         ######################################################################
283         # compute the readout
284
285         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
286
287         ar = torch.einsum(
288             "nhtd,ld->nhlt",
289             qr,
290             # self.rec_k[:, :, t0:t1],
291             self.k_star,
292         ) / math.sqrt(self.w_qr.size(1))
293
294         ar = ar.softmax(dim=2)  # nhlt
295
296         ar = F.dropout(ar, self.attention_dropout, self.training)
297
298         y = torch.einsum(
299             "nhlt,nltd->nthd",
300             ar,
301             self.rec_v[:, :, t0:t1],
302         ).flatten(2)
303
304         self.cache_y[:, t0:t1] = y @ self.w_o
305
306         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
312 class KVRec(nn.Module):
313     def __init__(
314         self,
315         dim_model,
316         dim_qk,
317         dim_v,
318         nb_heads,
319         nb_lines,
320         attention_dropout=0.0,
321         len_max=1e5,
322     ):
323         super().__init__()
324
325         def randw(*d):
326             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
327
328         self.nb_lines = nb_lines
329         self.attention_dropout = attention_dropout
330
331         self.k_star = randw(nb_lines, dim_qk)
332
333         self.w_qw = randw(nb_heads, dim_qk, dim_model)
334         self.w_qr = randw(nb_heads, dim_qk, dim_model)
335         self.w_k = randw(nb_heads, dim_qk, dim_model)
336         self.w_v = randw(nb_heads, dim_v, dim_model)
337         self.w_o = randw(dim_v * nb_heads, dim_model)
338
339     def reset_inner_loss(self):
340         self.acc_attention = 0
341         self.acc_nb = 0
342
343     def get_inner_loss(self):
344         warnings.warn("l2 regularization", RuntimeWarning)
345         return (self.acc_attention / self.acc_nb).pow(2).sum()
346         # return torch.tensor([0], device=self.w_qw.device)
347         # warnings.warn("side regularization", RuntimeWarning)
348         # return (
349         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
350         # )
351         # return torch.tensor([0], device=self.w_qw.device)
352
353     def forward(self, bs):
354         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
355
356         if bs.init_cache:
357             self.rec_v = x_q.new_zeros(
358                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
359             )
360             self.rec_k = x_q.new_zeros(
361                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
362             )
363             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
364
365         ######################################################################
366         # Prepare the keys
367
368         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
369
370         warnings.warn("rotating key barrel", RuntimeWarning)
371         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
372         t_barrel = torch.arange(t0, t1, device=k_star.device)
373         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
374         l_barrel = (
375             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
376         ) % k_star.size(0)
377         k_star = k_star[l_barrel, t_barrel]
378
379         ######################################################################
380         # Compute the recurrent state
381
382         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
383
384         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
385         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
386
387         aw = torch.einsum(
388             "nhtd,ltd->nhlt",
389             qw,
390             k_star,
391         ) / math.sqrt(self.w_qw.size(1))
392
393         aw = aw.softmax(dim=2)  # nhlt
394
395         if self.training:
396             # We want all the memory lines to be used similarly
397             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
398             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
399
400         aw = F.dropout(aw, self.attention_dropout, self.training)
401
402         A = 1 - aw.sum(dim=1)  # nlt
403
404         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
405         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
406
407         if t0 == 0:
408             V0 = None
409             K0 = None
410         else:
411             V0 = self.rec_v[:, :, t0 - 1]
412             K0 = self.rec_k[:, :, t0 - 1]
413
414         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
415         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
416
417         ######################################################################
418         # compute the readout
419
420         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
421
422         ar = torch.einsum(
423             "nhtd,nltd->nhlt",
424             qr,
425             self.rec_k[:, :, t0:t1],
426         ) / math.sqrt(self.w_qr.size(1))
427
428         ar = ar.softmax(dim=2)  # nhlt
429
430         ar = F.dropout(ar, self.attention_dropout, self.training)
431
432         y = torch.einsum(
433             "nhlt,nltd->nthd",
434             ar,
435             self.rec_v[:, :, t0:t1],
436         ).flatten(2)
437
438         self.cache_y[:, t0:t1] = y @ self.w_o
439
440         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
446 # Returns a view of x with an additional index at rank win_dim that moves
447 # along the same dimension as dim, over a domain {0...win_size-1}, while
448 # dim is restricted to a domain reduced by win_size-1 values.
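#
# For example (shapes only): if x has shape (2, 10, 5), then
# moving_window(x, dim=1, win_dim=2, win_size=3) is a view of shape
# (2, 8, 3, 5) with view[n, t, w, d] == x[n, t + w, d], and no data is
# copied.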
449
450
451 def moving_window(x, dim, win_dim, win_size):
452     size, stride = x.size(), x.stride()
453     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
454     size = size[:win_dim] + (win_size,) + size[win_dim:]
455     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
456
457     return x.as_strided(size=size, stride=stride)
458
459
460 ##############################
461
462
463 class Caterpillar(nn.Module):
464     def __init__(
465         self,
466         dim_model,
467         dim_qk,
468         dim_v,
469         nb_heads,
470         caterpillar_length,
471         caterpillar_height,
472         attention_dropout=0.0,
473         len_max=1e5,
474     ):
475         super().__init__()
476
477         warnings.warn("Caterpillar", RuntimeWarning)
478
479         def randw(*d, amplitude=None):
480             if amplitude is None:
481                 amplitude = 1 / math.sqrt(d[-1])
482             return nn.Parameter(amplitude * torch.randn(*d))
483
484         self.caterpillar_length = caterpillar_length
485         self.caterpillar_height = caterpillar_height
486         self.attention_dropout = attention_dropout
487
488         self.proba_gate_dropout = 0.0
489
490         self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
491         self.b_G = nn.Parameter(
492             torch.full(
493                 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
494             )
495         )
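        # At initialization sigmoid(b_G) equals 1/caterpillar_height (for a
        # height > 1), so each new key/value starts out spread roughly
        # uniformly over the rows of the recurrent state.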
496
497         self.w_K = randw(nb_heads, dim_qk, dim_model)
498         self.w_V = randw(nb_heads, dim_v, dim_model)
499         self.w_Q = randw(nb_heads, dim_qk, dim_model)
500         self.w_O = randw(dim_v * nb_heads, dim_model)
501
502         self.init_K_rec = randw(
503             caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
504         )
505         self.init_V_rec = randw(
506             caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
507         )
508
509     def reset_inner_loss(self):
510         self.acc_attention = 0
511         self.acc_nb = 0
512
513     def get_inner_loss(self):
514         # warnings.warn("l2 regularization", RuntimeWarning)
515         # return (self.acc_attention / self.acc_nb).pow(2).sum()
516         return torch.tensor([0], device=self.w_Q.device)
517
518     def forward(self, bs):
519         # Name the dimensions explicitly, to make the source a bit clearer
520
521         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
522
523         N = bs.x.size(0)
524         T = bs.x.size(1)
525         H = self.w_V.size(0)
526         DV = self.w_V.size(1)
527         DK = self.w_K.size(1)
528         DM = self.w_O.size(1)
529         CH = self.caterpillar_height
530         CL = self.caterpillar_length
531
532         assert (
533             t0 >= CL and (t1 - t0) % CL == 0
534         ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
535
536         # We cache values to deal efficiently with auto-regression
537
538         if bs.init_cache:
539             self.rec_V = X.new_zeros(N, CH, T, DV)
540             self.rec_K = X.new_zeros(N, CH, T, DK)
541             # We start the recurrent sequences with optimizable
542             # initial values. No idea if it helps.
543             self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
544             self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
545
546             self.cache_Y = X.new_zeros(N, T, DM)
547
548         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
549         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
550
551         ######################################################################
552         # Compute the recurrent state
553
554         # This is the Gating sequence that modulates the storing of
555         # the new key and value in the CH pairs of the current
556         # stack. There are CH independent gating values, which means
557         # that the current K/V may be stored in multiple pairs of the
558         # recurrent state, or not at all.
559
560         G = (
561             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
562         ).sigmoid()
563
564         ######################################################################
565         # The "flashbacks"
566
567         if self.training and self.proba_gate_dropout > 0.0:
568             # This is a better implementation of "flashbacks".
569
570             # G is NxHxRxT, where R indexes the caterpillar's rows.
571
572             warnings.warn("gate dropout", RuntimeWarning)
573             epsilon = 0.5
574
575             dropout_start = (
576                 (
577                     torch.rand(G.size(), device=G.device)
578                     .flatten(2, 3)
579                     .sort(dim=2)
580                     .indices
581                     == 0
582                 )
583                 .unflatten(2, (CH, t1 - t0))
584                 .float()
585             )
586
587             dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
588
589             dropout_active = (
590                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
591             ).long()
592
593             dropout_start *= dropout_active
594             dropout_tail *= dropout_active
595
596             G = (
597                 G
598                 + dropout_start * (1 - epsilon - G.detach())
599                 - dropout_tail * G.detach()
600             )
601
602         ######################################################################
603
604         # We prepare the arguments for the parallel scan
605
606         # Clip the gating to avoid values greater than 1 when several
607         # heads hit the same row
608
609         G = G / G.sum(1, keepdim=True).clamp(min=1)
610
611         A = 1 - G.sum(1)
612         gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
613         gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
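        # Informally, the update implemented by the scan below, applied
        # with period CL, is
        #
        #   rec[r, t] = (1 - sum_h G[h, r, t]) * rec[r, t - CL]
        #               + sum_h G[h, r, t] * V[h, t]
        #
        # and similarly for K.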
614
615         # We start from cached values, which matters in inference
616
617         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
618         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
619
620         #################################################################
621         # Associative scan
622
623         # Here there is a trick: Since the stack at position t is
624         # computed by updating that at position t-CL, the parallel
625         # scan operates with a period of CL. To do so we split the
626         # sequence indexing in two axes, the second of size CL, and
627         # run the parallel scan using the first as the sequence index.
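        # Concretely, A goes from NxRx(t1-t0) to NxRx((t1-t0)//CL)xCL, and
        # pscan_dim scans over dim=2, i.e. over blocks of CL time steps,
        # while the CL axis is carried along unchanged.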
628
629         A = A.unflatten(2, (-1, CL))
630         gated_V = gated_V.unflatten(2, (-1, CL))
631         gated_K = gated_K.unflatten(2, (-1, CL))
632
633         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
634         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
635
636         self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
637         self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
638
639         ######################################################################
640         # compute the readout
641
642         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
643
644         # We build tensors NxHxTxFxL where N is the sample index, H
645         # the head, T the time, F the row in the caterpillar, and L
646         # the column in the caterpillar
647
648         windowed_V = moving_window(
649             self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
650         )
651
652         windowed_K = moving_window(
653             self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
654         )
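        # For the query at time step t, these windows expose the CL most
        # recent versions of each of the CH recurrent rows, i.e. those
        # stored at positions t-CL+1, ..., t.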
655
656         # We have an attention score for each of the CHxCL values
657
658         ar = torch.einsum(
659             "nhtd,nftld->nhtfl",
660             Q,
661             windowed_K,
662         ) / math.sqrt(DK)
663
664         # softmax can operate only on one dimension, hence the
665         # flattening
666
667         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
668
669         ar = F.dropout(ar, self.attention_dropout, self.training)
670
671         # Compute the output for each head, flatten to concatenate
672
673         Y = torch.einsum(
674             "nhtfl,nftld->nthd",
675             ar,
676             windowed_V,
677         ).flatten(2)
678
679         # Compute the final output
680
681         self.cache_Y[:, t0:t1] = Y @ self.w_O
682
683         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
684
685
686 ##############################
687
688
689 class QKVAttention(nn.Module):
690     def __init__(
691         self,
692         dim_model,
693         dim_qk,
694         dim_v,
695         nb_heads=1,
696         causal=False,
697         attention_dropout=0.0,
698     ):
699         super().__init__()
700
701         def randw(*d):
702             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
703
704         self.causal = causal
705         self.attention_dropout = attention_dropout
706         self.record_attention = False
707
708         self.w_q = randw(nb_heads, dim_qk, dim_model)
709         self.w_k = randw(nb_heads, dim_qk, dim_model)
710         self.w_v = randw(nb_heads, dim_v, dim_model)
711         self.w_o = randw(dim_v * nb_heads, dim_model)
712
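    # For the current bracket [first, first+nb), only the new queries are
    # computed; the new keys and values are written into caches covering
    # the whole sequence, and attention is taken over all cached positions
    # up to first+nb.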
713     def forward(self, bs):
714         x_q = bs.x
715
716         assert (
717             self.causal or bs.complete()
718         ), "Partial evaluation is only possible for causal models"
719
720         if bs.init_cache:
721             self.cache_k = x_q.new_zeros(
722                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
723             )
724             self.cache_v = x_q.new_zeros(
725                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
726             )
727             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
728
729         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
730
731         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
732             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
733         )
734         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
735             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
736         )
737
738         a = torch.einsum(
739             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
740         ) / math.sqrt(self.w_q.size(1))
741
742         if self.causal:
743             if bs.init_cache:
744                 self.cache_attzero = (
745                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
746                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
747                 )
748             a = a.masked_fill(
749                 self.cache_attzero[
750                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
751                 ],
752                 float("-inf"),
753             )
754
755         a = a.softmax(dim=3)
756
757         if self.record_attention:
758             self.a = a
759
760         a = F.dropout(a, self.attention_dropout, self.training)
761
762         y = torch.einsum(
763             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
764         ).flatten(2)
765
766         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
767
768         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
769
770
771 ##############################
772
773
774 class MyGPT(nn.Module):
775     def __init__(
776         self,
777         vocabulary_size,
778         dim_model,
779         dim_keys,
780         dim_hidden,
781         nb_heads,
782         nb_blocks,
783         nb_lines=None,
784         caterpillar_height=None,
785         causal=False,
786         dropout=0.0,
787         len_max=1e5,
788         attention_layer="kvrec",
789     ):
790         super().__init__()
791
792         assert attention_layer in {
793             "mha",
794             "dumbrec",
795             "kvrec",
796             "caterpillar",
797         }, f"Unknown attention operator {attention_layer}."
798
799         if attention_layer == "caterpillar":
800             assert nb_lines % caterpillar_height == 0
801             self.caterpillar_length = nb_lines // caterpillar_height
802             self.caterpillar_height = caterpillar_height
803         else:
804             self.caterpillar_length = -1
805             self.caterpillar_height = -1
806
807         assert dim_model % nb_heads == 0
808
809         self.embedding = nn.Sequential(
810             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
811             AddPositionalEncoding(len_max),
812         )
813
814         trunk_blocks = []
815
816         def attlayer():
817             if attention_layer == "mha":
818                 return QKVAttention(
819                     dim_model=dim_model,
820                     dim_qk=dim_keys,
821                     dim_v=dim_model // nb_heads,
822                     nb_heads=nb_heads,
823                     causal=causal,
824                     attention_dropout=dropout,
825                 )
826             elif attention_layer == "dumbrec":
827                 return DumbRec(
828                     dim_model=dim_model,
829                     dim_qk=dim_keys,
830                     dim_v=dim_model // nb_heads,
831                     nb_heads=nb_heads,
832                     nb_lines=nb_lines,
833                     attention_dropout=dropout,
834                 )
835             elif attention_layer == "kvrec":
836                 return KVRec(
837                     dim_model=dim_model,
838                     dim_qk=dim_keys,
839                     dim_v=dim_model // nb_heads,
840                     nb_heads=nb_heads,
841                     nb_lines=nb_lines,
842                     attention_dropout=dropout,
843                 )
844             elif attention_layer == "caterpillar":
845                 return Caterpillar(
846                     dim_model=dim_model,
847                     dim_qk=dim_keys,
848                     dim_v=dim_model // nb_heads,
849                     nb_heads=nb_heads,
850                     caterpillar_length=self.caterpillar_length,
851                     caterpillar_height=self.caterpillar_height,
852                     attention_dropout=dropout,
853                 )
854             else:
855                 raise ValueError(f"Unknown attention type {attention_layer}.")
856
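        # Each block is pre-norm: LayerNorm followed by the attention layer
        # with a residual connection, then LayerNorm followed by a two-layer
        # feed-forward network, also with a residual connection.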
857         for b in range(nb_blocks):
858             trunk_blocks += [
859                 WithResidual(
860                     CacheWrapper(nn.LayerNorm((dim_model,))),
861                     attlayer(),
862                 ),
863                 WithResidual(
864                     CacheWrapper(
865                         nn.LayerNorm((dim_model,)),
866                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
867                         nn.ReLU(),
868                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
869                         nn.Dropout(dropout),
870                     ),
871                 ),
872             ]
873
874         self.trunk = nn.Sequential(*trunk_blocks)
875
876         self.readout = CacheWrapper(
877             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
878         )
879
880         with torch.no_grad():
881             for m in self.modules():
882                 if isinstance(m, nn.Embedding):
883                     m.weight.normal_(mean=0, std=2e-2)
884                 elif isinstance(m, nn.LayerNorm):
885                     m.bias.zero_()
886                     m.weight.fill_(1.0)
887
888         self.reset_inner_loss()
889
890     def forward(self, bs):
891         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
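        # F.pad(bs.x, (1, -1)) shifts the tokens one step to the right
        # (prepends a 0, drops the last one), so that the prediction at
        # position t depends only on tokens at positions < t.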
892
893         # To make the code simpler in the Caterpillar layer, we pad
894         # here. It is unclear if/how much it hurts computationally by
895         # increasing the sequence length for the other layers.
896
897         if self.caterpillar_length > 0:
898             original_nb = bs.nb
899             if bs.nb % self.caterpillar_length > 0:
900                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
901
902             bs = BracketedSequence(
903                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
904                 bs.first + self.caterpillar_length,
905                 bs.nb,
906                 bs.init_cache,
907             )
908
909         bs = self.embedding(bs)
910         bs = self.trunk(bs)
911         bs = self.readout(bs)
912
913         if self.caterpillar_length > 0:
914             bs = BracketedSequence(
915                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
916                 bs.first - self.caterpillar_length,
917                 original_nb,
918                 bs.init_cache,
919             )
920
921         return bs
922
923     # ar_mask is a tensor of 0s and 1s, of the same shape as the input,
924     # with 1s where tokens should be generated. The other tokens are kept
925     # unchanged.
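    #
    # For illustration (a rough sketch, assuming input is a batch of token
    # sequences):
    #
    #   ar_mask = torch.zeros_like(input)
    #   ar_mask[:, -10:] = 1  # regenerate the last 10 tokens
    #   model.masked_inplace_autoregression(input, ar_mask)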
926
927     def masked_inplace_autoregression(
928         self,
929         input_src,
930         ar_mask_src,
931         forbidden_tokens=None,
932         deterministic_synthesis=False,
933     ):
934         input = input_src.to(self.readout.f.weight.device)
935         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
936         to_generate = (ar_mask.sum(0) > 0).nonzero()
937         if to_generate.min() > 0:
938             self(
939                 BracketedSequence(input, 0, to_generate.min(), True)
940             )  # Needed to initialize the model's cache
941         for s in range(to_generate.min(), to_generate.max() + 1):
942             output = self(BracketedSequence(input, s, 1, s == 0)).x
943             logits = output[:, s]
944             if forbidden_tokens is not None:
945                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
946             if deterministic_synthesis:
947                 t_next = logits.argmax(1)
948             else:
949                 dist = torch.distributions.categorical.Categorical(logits=logits)
950                 t_next = dist.sample()
951             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
952
953         input_src.copy_(input)
954
955     def reset_inner_loss(self):
956         for m in self.modules():
957             if m is not self and hasattr(m, "reset_inner_loss"):
958                 m.reset_inner_loss()
959
960     def get_inner_loss(self):
961         l = torch.tensor([0.0], device=self.readout.f.weight.device)
962         for m in self.modules():
963             if m is not self and hasattr(m, "get_inner_loss"):
964                 l += m.get_inner_loss()
965         return l
966
967     def record_attention(self, v=True):
968         for m in self.modules():
969             if isinstance(m, QKVAttention):
970                 m.record_attention = v
971
972     def retrieve_attention(self):
973         a = []
974         for m in self.modules():
975             if isinstance(m, QKVAttention):
976                 a.append(m.a)
977         return a
978
979
980 ######################################################################
981
982 if __name__ == "__main__":
983     print("Basic check.")
984
985     m = Caterpillar(
986         dim_model=4,
987         dim_qk=3,
988         dim_v=7,
989         nb_heads=1,
990         caterpillar_length=7,
991         caterpillar_height=3,
992         attention_dropout=0.0,
993     )
994
995     m.reset_inner_loss()
996     x = torch.randn(1, 21 + 2 * 7, 4)
997     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
998     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
999     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1000     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1001     print((y1 - y2).abs().max())
1002     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1003     exit(0)
1004
1005     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1006
1007     vocabulary_size = 128
1008     x = torch.randint(vocabulary_size, (6, 1024))
1009
1010     model = MyGPT(
1011         vocabulary_size=vocabulary_size,
1012         dim_model=512,
1013         dim_keys=64,
1014         dim_hidden=2048,
1015         nb_heads=8,
1016         nb_lines=128,
1017         nb_blocks=12,
1018         dropout=0.1,
1019         causal=True,
1020     )
1021
1022     x = x.to(device)
1023     model.to(device)
1024
1025     import time, sys
1026
1027     # import torchvision.models as models
1028     # from torch.profiler import profile, record_function, ProfilerActivity
1029
1030     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1031     # with record_function("model_inference"):
1032
1033     model.eval()
1034     for i in range(3):
1035         start_time = time.perf_counter()
1036         for k in range(10):
1037             model(BracketedSequence(x))
1038         duration = time.perf_counter() - start_time
1039         print(duration)
1040         sys.stdout.flush()
1041
1042     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1043     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1044
1045     # print("##############################################################")
1046     # y2 = torch.randn_like(y1)
1047     # for s in range(x.size(1)):
1048     # z = model(BracketedSequence(x, s, 1))
1049     # y2[:, s : s + 1] = z.slice()
1050
1051     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1052
1053 ######################################################################