Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
26 # A BracketedSequence is a BxTx... tensor with a first and a nb time
27 # steps to compute.
28
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
32 # no holes.
33 #
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
38 #
# Modules able to process brackets may implement a cache that is
# reset when the input bracket starts at t=0.
41
42
class BracketedSequence:
    """A tensor `x` together with the time bracket [first, first + nb) to process.

    Either all of `first`, `nb` and `init_cache` are given, or none of
    them. When none is given the bracket covers the whole of dimension
    1 of `x` and `init_cache` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x

        # The three optional arguments go together: all or nothing.
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the part of `x` inside the bracket, along dimension 1."""
        start = self.first
        stop = start + self.nb
        return self.x[:, start:stop]

    def complete(self):
        """Tell whether the bracket spans the full dimension 1 of `x`."""
        starts_at_origin = self.first == 0
        return starts_at_origin and self.nb == self.x.size(1)
59
60
61 ######################################################################
62
63
class CacheWrapper(nn.Module):
    """Apply a time-wise module to the current bracket and cache its output.

    The wrapped module(s) only see `bs.slice()`. Their result is written
    into `self.cache_y`, a buffer whose dimension 1 has the size of
    dimension 1 of `bs.x`, and which is re-allocated whenever
    `bs.init_cache` is set.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        start, stop = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # Continuing a sequence: the existing buffer must match the input.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert stop <= self.cache_y.size(1)
            self.cache_y[:, start:stop] = self.f(bs.slice())
        else:
            y = self.f(bs.slice())
            # Uninitialized buffer: batch size and trailing dimensions of
            # y, time dimension of the full input.
            self.cache_y = y.new_empty((y.size(0), bs.x.size(1), *y.size()[2:]))
            self.cache_y[:, start:stop] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
class WithResidual(nn.Module):
    """Residual connection around one module, or a sequence of modules.

    The wrapped module(s) take and return a BracketedSequence; the
    output tensor is the sum of the input tensor and of theirs.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs)
        return BracketedSequence(skip + branch.x, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to the bracketed part of the input.

    The encoding table is rebuilt, for the full length of dimension 1
    of the input, each time `bs.init_cache` is set.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        start, stop = bs.first, bs.first + bs.nb

        if bs.init_cache:
            nb_steps, dim = bs.x.size(1), bs.x.size(2)
            like_x = dict(dtype=bs.x.dtype, device=bs.x.device)
            t = torch.arange(nb_steps, **like_x)[:, None]
            j = torch.arange(dim, **like_x)[None, :]
            # k is 1 on odd channels; the pi/2 phase shift turns their
            # sine into a cosine.
            k = j % 2
            angle = t / (self.len_max ** ((j - k) / dim)) + math.pi / 2 * k
            self.pe = torch.sin(angle)
            self.cache_y = bs.x.new_empty(bs.x.size())

        self.cache_y[:, start:stop] = bs.slice() + self.pe[start:stop]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
def pscan_dim(A, X, Y_init, dim=-2):
    """Run `pscan.pscan` along dimension `dim` of X and return a tensor shaped like X.

    The dimensions before `dim` are merged into one, those strictly
    between `dim` and the last one are kept as they are. When `Y_init`
    is None a zero initial state is used.
    """
    full_shape = X.size()
    nb_lead = full_shape[:dim].numel()
    nb_steps = full_shape[dim]
    inner = tuple(full_shape[dim + 1 : -1])

    A = A.reshape(nb_lead, nb_steps, *inner)
    X = X.reshape(nb_lead, nb_steps, *inner, -1)

    if Y_init is not None:
        Y_init = Y_init.reshape(nb_lead, *inner, -1)
    else:
        Y_init = X.new_zeros(nb_lead, *inner, X.size(-1))

    return pscan.pscan(A, X, Y_init).reshape(full_shape)
146
147
def pscan_shape(A, X, Y_init):
    """Run `pscan.pscan` along the second-to-last dimension of X.

    All the leading dimensions are merged into one for the call, and
    the result is reshaped back to the shape of X. When `Y_init` is
    None a zero initial state is used.
    """
    full_shape = X.size()
    nb_steps, dim = full_shape[-2], full_shape[-1]

    A = A.reshape(-1, nb_steps)
    X = X.reshape(-1, nb_steps, dim)

    if Y_init is not None:
        Y_init = Y_init.reshape(-1, dim)
    else:
        Y_init = X.new_zeros(X.size(0), dim)

    return pscan.pscan(A, X, Y_init).reshape(full_shape)
161
162
def nsum_shape(X, Y_init):
    """Cumulative sum along the second-to-last dimension, with norm capping.

    After each step the running state is divided by max(1, its L2 norm
    over the last dimension), so its norm never exceeds 1. When
    `Y_init` is None the state starts from zero. The result has the
    shape of X.
    """
    full_shape = X.size()
    steps = X.reshape(-1, full_shape[-2], full_shape[-1])  # ntd

    state = 0 if Y_init is None else Y_init.reshape(-1, full_shape[-1])
    history = []

    for x_t in steps.unbind(dim=1):
        state = state + x_t
        state = state / state.norm(dim=-1, keepdim=True).clamp(min=1)
        history.append(state)

    return torch.stack(history, dim=1).reshape(full_shape)
176
177
178 ##############################
179
180
class DumbRec(nn.Module):
    """Recurrent attention layer with `nb_lines` memory lines of values.

    At each time step the input is written into the memory lines with
    attention weights computed against the fixed, learned keys
    `k_star`, rotated by one line per time step. The memory is updated
    with a parallel scan, and read out with attention weights computed
    against the same (un-rotated) keys.

    Args:
        dim_model: size of the input and output features.
        dim_qk: size of the queries and keys.
        dim_v: size of the values stored in the memory lines.
        nb_heads: number of heads.
        nb_lines: number of memory lines.
        attention_dropout: dropout probability applied to both the
            writing and the reading attention weights.
        len_max: not used by this layer.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # So that forward() can accumulate statistics even if the
        # caller never resets them explicitly.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Clear the attention statistics used by get_inner_loss()."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squares of the mean writing attention per line.

        Returns a zero scalar when no statistics have been accumulated
        since the last reset.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # At time t, memory line l is addressed with key (l + t) % nb_lines.

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag; self.train is the method that
        # sets it, and is always truthy.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # The scan restarts from zero at t=0, and from the last cached
        # state otherwise.
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
class KVRec(nn.Module):
    """Recurrent attention layer with `nb_lines` memory lines of key/value pairs.

    At each time step the input is written into the memory lines with
    attention weights computed against the fixed, learned keys
    `k_star`, rotated by one line per time step. Both a key and a value
    are stored, the memory is updated with parallel scans, and it is
    read out with attention weights computed against the stored keys.

    Args:
        dim_model: size of the input and output features.
        dim_qk: size of the queries and keys.
        dim_v: size of the values stored in the memory lines.
        nb_heads: number of heads.
        nb_lines: number of memory lines.
        attention_dropout: dropout probability applied to both the
            writing and the reading attention weights.
        len_max: not used by this layer.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # So that forward() can accumulate statistics even if the
        # caller never resets them explicitly.
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Clear the attention statistics used by get_inner_loss()."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squares of the mean writing attention per line.

        Returns a zero scalar when no statistics have been accumulated
        since the last reset.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # At time t, memory line l is addressed with key (l + t) % nb_lines.

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag; self.train is the method that
        # sets it, and is always truthy.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # The scans restart from zero at t=0, and from the last cached
        # state otherwise.
        if t0 == 0:
            V0, K0 = None, None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, on a domain {0...win_size-1}, and
# dim is restricted to a domain reduced by win_size-1 values.
447
448
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with sliding windows taken along `dim`.

    Dimension `dim` is shortened by win_size - 1, and a new dimension
    of size `win_size` is inserted at rank `win_dim`. The new dimension
    reuses the stride of `dim`, so no data is copied.
    """
    dims = tuple(x.size())
    steps = tuple(x.stride())

    # Shrink the windowed dimension first, then insert the window axis.
    shrunk = dims[:dim] + (dims[dim] - win_size + 1,) + dims[dim + 1 :]
    out_size = (*shrunk[:win_dim], win_size, *shrunk[win_dim:])
    out_stride = (*steps[:win_dim], steps[dim], *steps[win_dim:])

    return x.as_strided(size=out_size, stride=out_stride)
456
457
458 ##############################
459
460
class Caterpillar(nn.Module):
    """Attention layer over a recurrent state of CH x CL key/value pairs.

    CH is `caterpillar_height` and CL is `caterpillar_length`. The
    state at time t is computed by updating the state at time t - CL
    with gated versions of the current key and value. The output at
    time t attends to the CH x CL pairs of the CL most recent states.

    forward() requires bs.first >= CL and bs.nb to be a multiple of CL.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        # len_max is accepted but not used by this layer.
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Gating parameters. The initial bias makes the sigmoid equal
        # to 1 / caterpillar_height. Note that math.log raises for
        # caterpillar_height == 1.
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned initial content of the recurrent state
        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        """Clear the accumulators. forward() does not update them in this layer."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return a zero tensor of shape [1]: this layer has no inner loss."""
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        """Process the bracket of `bs` and return the cached output as a BracketedSequence."""
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        # The CL time steps before t0 hold the initial recurrent state,
        # and the scan below splits the bracket in chunks of length CL.
        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in all the pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        # The gates are summed over the heads: A is indexed by sample,
        # row of the caterpillar, and time.
        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        # The state over the CL steps preceding the bracket
        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        # The windowed states below are indexed "nftld", with a window
        # of CL consecutive time steps ending at each time step of the
        # bracket; the head index appears in the attention scores.

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
627
628
629 ##############################
630
631
class QKVAttention(nn.Module):
    """Multi-head attention with a cache of keys and values.

    The keys, values and outputs of the time steps already processed
    are kept in buffers, which are re-allocated whenever
    `bs.init_cache` is set. Only the time steps of the current bracket
    are computed.

    When `record_attention` is set, the attention weights after the
    softmax of the last call are stored in `self.a`.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            nb_samples, nb_steps = x_q.size(0), x_q.size(1)
            self.cache_k = x_q.new_zeros(
                nb_samples, self.w_k.size(0), nb_steps, self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                nb_samples, self.w_v.size(0), nb_steps, self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(nb_samples, nb_steps, self.w_o.size(1))

        # Queries, keys and values are computed for the bracket only

        x_new = x_q[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # The new queries look at all the keys cached up to t1

        past_k = self.cache_k[:, :, :t1]
        a = torch.einsum("nhtd,nhsd->nhts", q, past_k) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                # True where the key is strictly after the query
                position = torch.arange(x_q.size(1), device=q.device)
                self.cache_attzero = (
                    position[None, None, :, None] < position[None, None, None, :]
                )
            forbidden = self.cache_attzero[:, :, t0:t1, :t1]
            a = a.masked_fill(forbidden, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        past_v = self.cache_v[:, :, :t1]
        y = torch.einsum("nhts,nhsd->nthd", a, past_v).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
712
713
714 ##############################
715
716
class MyGPT(nn.Module):
    """Stack of attention / MLP residual blocks between an embedding and a readout.

    Args:
        vocabulary_size: number of tokens, and size of the output logits.
        dim_model: size of the features in the trunk; must be a
            multiple of nb_heads.
        dim_keys: size of the queries and keys of the attention layers.
        dim_hidden: size of the hidden layer of the MLP blocks.
        nb_heads: number of heads of the attention layers.
        nb_blocks: number of (attention, MLP) pairs of residual blocks.
        nb_lines: number of memory lines for "dumbrec" and "kvrec";
            for "caterpillar" it must be a multiple of
            caterpillar_height and sets the caterpillar length.
        caterpillar_height: height of the caterpillar, only for
            "caterpillar".
        dim_rec_v: passed as dim_v to the "dumbrec", "kvrec" and
            "caterpillar" layers. NOTE(review): the default -1 is
            forwarded as is, so it presumably has to be set by the
            caller for these layers.
        causal: passed to the "mha" layers only.
        dropout: dropout probability, used after the embedding, in the
            attention layers, and at the end of the MLP blocks.
        len_max: base constant of the positional encoding.
        attention_layer: one of "mha", "dumbrec", "kvrec", "caterpillar".
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        dim_rec_v=-1,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            # Negative values disable the padding done in forward()
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        """Compute the logits for the bracket of `bs`.

        Returns a BracketedSequence whose tensor holds the logits.
        """
        # Shift the tokens by one to the right along the last
        # dimension, so that the output at a position is computed from
        # the tokens before it.
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            # Round the number of steps up to a multiple of the length
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            # Remove the padding along the time dimension
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        """Generate in place the tokens of `input_src` selected by `ar_mask_src`.

        Args:
            input_src: tensor of tokens, modified in place.
            ar_mask_src: tensor of 0s and 1s of the same shape, with 1s
                where tokens should be generated.
            forbidden_tokens: optional mask over the logits; the masked
                logits are set to -inf before the token is chosen.
            deterministic_synthesis: take the argmax of the logits
                instead of sampling.

        Does nothing when the mask selects no position.
        """
        to_generate = (ar_mask_src.sum(0) > 0).nonzero()

        # min() / max() are not defined on an empty tensor
        if to_generate.numel() == 0:
            return

        device = self.readout.f.weight.device
        tokens = input_src.to(device)
        ar_mask = ar_mask_src.to(device)

        # Plain ints, so that the brackets carry ints and not 0-d tensors
        s_first, s_last = int(to_generate.min()), int(to_generate.max())

        if s_first > 0:
            self(
                BracketedSequence(tokens, 0, s_first, True)
            )  # Needed to initialize the model's cache

        for s in range(s_first, s_last + 1):
            output = self(BracketedSequence(tokens, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            tokens[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * tokens[:, s]

        input_src.copy_(tokens)

    def reset_inner_loss(self):
        """Call reset_inner_loss() on every sub-module that defines it."""
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        """Return the sum of the inner losses of the sub-modules, as a tensor of shape [1]."""
        total = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                total += m.get_inner_loss()
        return total

    def record_attention(self, v=True):
        """Switch the recording of the attention weights in the QKVAttention layers."""
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        """Return the list of attention weights recorded by the QKVAttention layers."""
        return [m.a for m in self.modules() if isinstance(m, QKVAttention)]
917
918
919 ######################################################################
920
if __name__ == "__main__":
    print("Basic check.")

    # Consistency check of the Caterpillar layer: processing the 21
    # time steps at once, twice, and as two successive brackets of 14
    # and 7 steps. Both printed differences should be close to zero.

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    # 7 steps of padding on each side of the 21 steps to process
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # exit() is a helper added by the site module for interactive use;
    # SystemExit works in every context. The timing code below is
    # currently not reached.
    raise SystemExit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # NOTE(review): dim_rec_v is left at its -1 default here, and is
    # passed as dim_v to the default "kvrec" layers.
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    # Time three rounds of ten full forward passes
    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
991
992 ######################################################################