Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
# A BracketedSequence wraps a BxTx... tensor x together with the index
# `first` of the first time step to compute and the number `nb` of
# time steps to compute.
28
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
32 # no holes.
33 #
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
38 #
# Modules able to process brackets may implement a cache that is
# reset when the input bracket has init_cache set (by default, the
# bracket starting at t=0).
41
42
class BracketedSequence:
    """A BxTx... tensor plus the bracket [first, first + nb) of time steps to compute.

    Either all of ``first``, ``nb`` and ``init_cache`` are given, or none of
    them is. In the latter case the bracket covers the whole time axis and
    ``init_cache`` is True, so modules start fresh caches.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        if first is None:
            first = 0
        if nb is None:
            nb = x.size(1)
        if init_cache is None:
            init_cache = True

        self.first, self.nb, self.init_cache = first, nb, init_cache

    def slice(self):
        """Return the time steps [first, first + nb) of x."""
        end = self.first + self.nb
        return self.x[:, self.first : end]

    def complete(self):
        """Tell whether the bracket spans the entire time axis of x."""
        return (self.first, self.nb) == (0, self.x.size(1))
59
60
61 ######################################################################
62
63
class CacheWrapper(nn.Module):
    """Apply plain tensor module(s) to the current bracket and cache the output.

    ``f`` (a single module, or several chained with ``nn.Sequential``) is
    called on ``bs.slice()`` only. Its output is written into a buffer that
    covers the whole time axis, and that buffer is returned as a new
    BracketedSequence with the same bracket.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled rather than uninitialized: time steps outside the
            # brackets computed so far are part of the returned tensor and
            # should hold defined values, not arbitrary memory.
            self.cache_y = y.new_zeros(
                (y.size(0), bs.x.size(1)) + tuple(y.size()[2:])
            )
            self.cache_y[:, t0:t1] = y
        else:
            # Continuing a previous pass: the cache must match this input
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert t1 <= self.cache_y.size(1)
            self.cache_y[:, t0:t1] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
class WithResidual(nn.Module):
    """Residual connection for BracketedSequence modules: x + f(x).

    The sum is taken over the whole time axis and the bracket is kept.
    """

    def __init__(self, *f):
        super().__init__()
        # One module is used directly, several are chained
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs)
        return BracketedSequence(skip + branch.x, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a BracketedSequence.

    The encoding table for the whole time axis is rebuilt when
    ``bs.init_cache`` is set. Each call writes ``slice + encoding`` for the
    bracket into a cached output that covers the whole time axis.
    ``len_max`` is the base L of the encoding below.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    # with L = len_max and D the size of the last axis of x

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # Odd channels use sin(u + pi/2) = cos(u)
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled rather than uninitialized, so that time steps
            # outside the brackets computed so far hold defined values.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan on X along axis ``dim``, with gates A.

    X has its scanned axis T at position ``dim`` and a feature axis last.
    A has as many elements as X without its last axis. Y_init (optional,
    zeros by default) has as many elements as X without axis ``dim``.
    The axes before ``dim`` are merged into one batch axis for the scan,
    and the result is reshaped back to the shape of X.
    """
    shape = X.size()
    batch = shape[:dim].numel()
    nb_steps = shape[dim]
    middle = shape[dim + 1 : -1]

    A = A.reshape(batch, nb_steps, *middle)
    X = X.reshape(batch, nb_steps, *middle, -1)

    if Y_init is not None:
        Y_init = Y_init.reshape(batch, *middle, -1)
    else:
        Y_init = X.new_zeros(batch, *middle, X.size(-1))

    return pscan.pscan(A, X, Y_init).reshape(shape)
146
147
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan on X along its second-to-last axis.

    X is /.../xTxD, A is /.../xT and Y_init (optional, zeros by default) is
    /.../xD. All leading axes are flattened into one batch axis for the
    scan, and the result is reshaped to the shape of X.
    """
    shape = X.size()
    nb_steps, dim = shape[-2], shape[-1]

    flat_A = A.reshape(-1, nb_steps)
    flat_X = X.reshape(-1, nb_steps, dim)

    if Y_init is not None:
        flat_init = Y_init.reshape(-1, dim)
    else:
        flat_init = flat_X.new_zeros(flat_X.size(0), dim)

    return pscan.pscan(flat_A, flat_X, flat_init).reshape(shape)
161
162
def nsum_shape(X, Y_init):
    """Running sum of X over its second-to-last axis, with the state norm clamped to 1.

    X is /.../xTxD and Y_init (optional) is /.../xD. After adding each
    time step, the state is divided by max(1, ||state||), so its L2 norm
    over the last axis never exceeds 1. Returns all the successive states,
    with the shape of X.
    """
    shape = X.size()
    flat = X.reshape(-1, shape[-2], shape[-1])  # ntd

    if Y_init is None:
        state = 0
    else:
        state = Y_init.reshape(-1, shape[-1])

    states = []
    for step in range(flat.size(1)):
        state = state + flat[:, step]
        scale = state.norm(dim=-1, keepdim=True).clamp(min=1)
        state = state / scale
        states.append(state)

    return torch.cat(states, dim=1).reshape(shape)
176
177
178 ##############################
179
180
class DumbRec(nn.Module):
    """Recurrent memory layer with nb_lines memory lines and learned fixed keys.

    Write: each head attends (softmax over the lines) from a write query to
    the keys ``k_star``, rotated over time (line l uses k_star[(l + t) % L]
    at time t). The attention-weighted values are accumulated into the
    lines with a parallel scan.

    Read: each head attends from a read query to the (unrotated)
    ``k_star`` and reads the line contents.

    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        # Gaussian init scaled by 1/sqrt of the last dimension
        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean write attention per memory line.

        The write attention sums to 1 over the lines, so this is smallest when
        the lines are used evenly. Returns a zero tensor when nothing was
        accumulated since reset_inner_loss (e.g. eval-mode passes only),
        instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.k_star.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # rec_v[n, l, t] is the content of memory line l after step t
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % L]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # (nb_lines, t1 - t0, dim_qk)

        ######################################################################
        # Compute the recurrent state

        x_new = x_q[:, t0:t1]
        qw = torch.einsum("ntc,hdc->nhtd", x_new, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        aw = torch.einsum("nhtd,ltd->nhlt", qw, k_star) / math.sqrt(
            self.w_qw.size(1)
        )
        aw = aw.softmax(dim=2)  # nhlt, normalized over the memory lines

        # Usage statistics for get_inner_loss, in training mode only.
        # (self.train is a method and always truthy; self.training is the flag.)
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Gate of the parallel scan: 1 - total write weight over the heads
        A = 1 - aw.sum(dim=1)  # nlt
        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # When resuming a sequence, continue from the state at t0 - 1
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_new, self.w_qr)

        ar = torch.einsum("nhtd,ld->nhlt", qr, self.k_star) / math.sqrt(
            self.w_qr.size(1)
        )
        ar = ar.softmax(dim=2)  # nhlt
        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum("nhlt,nltd->nthd", ar, self.rec_v[:, :, t0:t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
class KVRec(nn.Module):
    """Recurrent key/value memory layer with nb_lines memory lines.

    Write: each head attends (softmax over the lines) from a write query to
    learned keys ``k_star``, rotated over time (line l uses
    k_star[(l + t) % L] at time t). The attention-weighted values AND keys
    of the input are accumulated into the lines with a parallel scan.

    Read: each head attends from a read query to the accumulated keys of
    the lines and reads their accumulated values.

    ``len_max`` is accepted for interface compatibility and is unused.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        # Gaussian init scaled by 1/sqrt of the last dimension
        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean write attention per memory line.

        The write attention sums to 1 over the lines, so this is smallest when
        all the lines are used similarly. Returns a zero tensor when nothing
        was accumulated since reset_inner_loss (e.g. eval-mode passes only),
        instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.k_star.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # rec_v / rec_k[n, l, t]: content of memory line l after step t
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % L]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # (nb_lines, t1 - t0, dim_qk)

        ######################################################################
        # Compute the recurrent state

        x_new = x_q[:, t0:t1]
        qw = torch.einsum("ntc,hdc->nhtd", x_new, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)

        aw = torch.einsum("nhtd,ltd->nhlt", qw, k_star) / math.sqrt(
            self.w_qw.size(1)
        )
        aw = aw.softmax(dim=2)  # nhlt, normalized over the memory lines

        # We want all the memory lines to be used similarly: accumulate the
        # usage statistics (sum across NxHx_xT) in training mode only.
        # (self.train is a method and always truthy; self.training is the flag.)
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Gate of the parallel scan: 1 - total write weight over the heads
        A = 1 - aw.sum(dim=1)  # nlt
        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # When resuming a sequence, continue from the state at t0 - 1
        if t0 == 0:
            V0, K0 = None, None
        else:
            V0, K0 = self.rec_v[:, :, t0 - 1], self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_new, self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt", qr, self.rec_k[:, :, t0:t1]
        ) / math.sqrt(self.w_qr.size(1))
        ar = ar.softmax(dim=2)  # nhlt
        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum("nhlt,nltd->nthd", ar, self.rec_v[:, :, t0:t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
# Returns a tensor with an additional index at rank win_dim that moves
# along the same dimension as dim, over the domain {0...win_size-1},
# while dim itself is restricted to a domain shortened by win_size-1
# values.
447
448
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with a sliding-window axis inserted at win_dim.

    Along ``dim`` the result has size[dim] - win_size + 1 entries. The new
    axis of size ``win_size``, inserted at position ``win_dim`` of the
    size tuple, moves along the same dimension, so that (for win_dim
    after dim) y[..., t, ..., w, ...] = x[..., t + w, ...].

    No data is copied. The view keeps x's storage offset, so it works on
    slices. A negative ``dim`` counts from the last axis of x.
    """
    # Normalize a negative dim: with dim=-1, size[dim + 1 :] would be the
    # whole tuple instead of an empty one.
    dim = dim % x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
456
457
458 ##############################
459
460 # This is one order of magnitude more complicated than I expected
461
462
def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
    # Sample random "flash back" windows and return, for each sample n,
    # caterpillar row c and time step t of the bracket (t relative to t0,
    # i.e. t in 0 .. t1 - t0 - 1):
    #   src_time[n, c, t]: the time step to copy from (same relative
    #                      indexing), or -1 when no flash back covers t;
    #   src_head[n, c, t]: the head to copy from, uniform in 0..H-1
    #                      (0 outside flash backs).
    # Both are integer tensors of shape N x CH x (t1 - t0). A flash back
    # covers CL consecutive steps, and at every one of them the source is
    # CL * (1 + k) steps in the past, with k drawn once per flash back.
    # Only t1 - t0 is used; t0 itself does not offset the returned indices.

    # starting flash backs
    # Each step starts a flash back with probability proba. None may start
    # in the first CL or in the last CL steps, so every window fits in the
    # bracket and has some past to copy from.
    fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long()
    fb_start[:, :, -CL:] = 0
    fb_start[:, :, :CL] = 0

    # Remove overlapping flash backs: fb_body[t] counts the starts in the
    # CL + 1 steps ending at t. A start is kept only if it is the only one
    # there (no other start in the CL previous steps), so the kept
    # windows never overlap.
    fb_body = fb_start.clone()
    fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)]
    fb_body = fb_body.cumsum(dim=2)
    fb_start = fb_start * (fb_body == 1)

    # Set an origin source time (the time at which the chunk to copy here
    # starts). It is a multiple of CL steps in the past, to be consistent
    # with the "rolling" caterpillar: at a start t, the base is
    # t - CL * (1 + k), with k uniform in {0, ..., t // CL - 2} (k = 0
    # when t < 2 * CL).
    t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :]
    src_time = fb_start * (
        t
        - CL
        * (
            1
            + (
                torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1)
            ).long()
        )
    )
    # Windowed sum over the last CL steps: spreads the base over the CL
    # steps of its window (0 elsewhere, since windows do not overlap)
    src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
    src_time = src_time.cumsum(dim=2)

    # Same for the source head
    src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device)
    src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL]
    src_head = src_head.cumsum(dim=2)

    # combine
    # The first cumsum gives the window indicator. Subtracting CL at the
    # step after each window turns the second cumsum into a ramp 1..CL
    # that drops back to 0. After "- 1" the offset is 0..CL-1 inside the
    # windows and -1 elsewhere, so src_time = base + offset inside a
    # window and -1 outside.
    src_delta = fb_start.clone()
    src_delta[:, :, CL:] -= fb_start[:, :, :-CL]
    src_delta = src_delta.cumsum(dim=2)
    src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL]
    src_time += src_delta.cumsum(dim=2) - 1

    return src_time, src_head
504
505
def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
    """Overwrite random windows of the recurrent state with earlier values.

    Flash backs are drawn with flash_back_time_src. Wherever one is active
    (source time >= 0), rec_V[:, :, t0:t1] (resp. rec_K) takes the entry
    of V (resp. K) at the drawn source head and source time. rec_V and
    rec_K are modified in place; nothing is returned.

    NOTE(review): V and K are indexed with [:, :, t0:t1] while the source
    times are relative to the bracket start. In Caterpillar.forward, V and
    K are computed from bs.slice(), so they already start at t0, which
    would shift every source by t0 steps. Confirm this is intended.
    """
    N, H, CH = V.size(0), V.size(1), rec_V.size(1)

    src_time, src_head = flash_back_time_src(
        N, H, t0, t1, CL, CH, proba, rec_V.device
    )

    # Identical procedure for the values, then for the keys
    for rec, src in ((rec_V, V), (rec_K, K)):
        target = rec[:, :, t0:t1]
        fb_time = src_time[:, :, :, None].expand_as(target)
        fb_head = src_head[:, :, :, None].expand_as(target)
        idx_n = torch.arange(src.size(0), device=src.device)[
            :, None, None, None
        ].expand_as(target)
        idx_d = torch.arange(src.size(3), device=src.device)[
            None, None, None, :
        ].expand_as(target)
        picked = src[:, :, t0:t1][idx_n, fb_head, fb_time.clamp(min=0), idx_d]
        rec[:, :, t0:t1] = picked * (fb_time >= 0) + target * (fb_time < 0)
536
537
538 ######################################################################
539
540
class Caterpillar(nn.Module):
    """Recurrent attention layer with a "caterpillar" memory.

    The recurrent state at time t is a stack of CH (caterpillar_height)
    key/value pairs, stored in rec_K / rec_V[:, :, t]. The state at t is
    obtained by updating the state at t - CL (caterpillar_length) with
    gated new keys/values, so the updates form CL interleaved recurrences
    computed with a parallel scan. A query at time t attends over the
    CH x CL pairs of the states at times t - CL + 1, ..., t.

    forward requires bs.first >= CL and bs.nb to be a multiple of CL.
    len_max is accepted but not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        # Gaussian init scaled by 1/sqrt of the last dimension
        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Gates: one logit per head and caterpillar row. With a zero input
        # term, each gate equals sigmoid(-log(CH - 1)) = 1 / CH.
        # NOTE(review): math.log(0) raises for caterpillar_height == 1.
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned content of the CL states preceding the first bracket
        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No inner loss for this layer: always a zero tensor (the
        # accumulators reset above are not used).
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        # The CL states before the bracket must exist (t0 >= CL, despite
        # the message saying "greater than"), and the periodic scan below
        # needs whole periods.
        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.
        # G is N x H x CH x (t1 - t0).

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # That was a bad idea
        # G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan:
        # A (N x CH x (t1 - t0)) is 1 - the total gate over the heads, and
        # gated_V / gated_K (N x CH x (t1 - t0) x D) are the gated
        # new values/keys summed over the heads.

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        # States of the CL steps just before the bracket
        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        # Training only: randomly overwrite windows of the state with
        # earlier keys/values (in place on rec_V / rec_K)
        warnings.warn("flash back", RuntimeWarning)
        if self.training:
            insert_flash_back(self.rec_V, V, self.rec_K, K, t0, t1, CL, proba=1e-2 / CL)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build the tensors windowed_V / windowed_K, of shape
        # N x F x T x L x D, where N is the sample index, F the row in the
        # caterpillar, T the time, L the column in the caterpillar (the
        # CL states ending at each time step) and D the feature. The
        # attention scores below are N x H x T x F x L, with H the head.

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
712
713
714 ##############################
715
716
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product attention over a BracketedSequence.

    Keys and values of every processed time step are cached, so that the
    sequence can be evaluated bracket by bracket. Such partial evaluation
    is only allowed when ``causal`` is True. When ``record_attention`` is
    set, the last attention weights (before dropout) are kept in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def init_weight(*shape):
            # Gaussian entries scaled by 1/sqrt of the last dimension
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Kept in this order so that the random initialization is unchanged
        self.w_q = init_weight(nb_heads, dim_qk, dim_model)
        self.w_k = init_weight(nb_heads, dim_qk, dim_model)
        self.w_v = init_weight(nb_heads, dim_v, dim_model)
        self.w_o = init_weight(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        start, stop = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            nb_samples, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(
                nb_samples, self.w_k.size(0), length, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                nb_samples, self.w_v.size(0), length, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(nb_samples, length, self.w_o.size(1))

        x_new = x[:, start:stop]
        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Queries of the bracket against every key up to its end
        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :stop]) / scale

        if self.causal:
            if bs.init_cache:
                # True where the key time s comes after the query time t
                pos = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = (pos[:, None] < pos[None, :])[None, None]
            future = self.cache_attzero[:, :, start:stop, :stop]
            a = a.masked_fill(future, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :stop]).flatten(2)
        self.cache_y[:, start:stop] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
797
798
799 ##############################
800
801
802 class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        dim_rec_v=-1,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        """Build the embedding, a trunk of residual blocks, and the readout.

        The trunk has nb_blocks pairs of residual blocks: LayerNorm followed
        by the layer selected by ``attention_layer`` ("mha": QKVAttention;
        "dumbrec", "kvrec", "caterpillar": recurrent layers), then LayerNorm
        followed by a two-layer ReLU MLP of hidden size ``dim_hidden``.

        For "caterpillar", ``nb_lines`` must be a multiple of
        ``caterpillar_height``, and the caterpillar length is
        nb_lines // caterpillar_height. ``dim_rec_v`` is the value dimension
        of the recurrent layers. NOTE(review): the default -1 would make
        their randw initialization fail; presumably it must be set for them.
        ``len_max`` is the base of the sinusoidal positional encoding.
        """
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        # forward() pads the sequence only when caterpillar_length > 0, so
        # other layers are flagged with -1
        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        # Builds a fresh attention-like layer of the selected type
        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        # The construction order of the modules determines the random draws
        # of their initialization
        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        # Override the default init: embeddings ~ N(0, 0.02^2), LayerNorms
        # start as the identity affine map
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        # presumably defined further down in the class (not visible here)
        self.reset_inner_loss()
913
914     def forward(self, bs):
915         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
916
917         # To make the code simpler in the Caterpillar layer, we pad
918         # here. It's unclear if/how much it hurts computationaly by
919         # increasing the sequence length for the other layers
920
921         if self.caterpillar_length > 0:
922             original_nb = bs.nb
923             if bs.nb % self.caterpillar_length > 0:
924                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
925
926             bs = BracketedSequence(
927                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
928                 bs.first + self.caterpillar_length,
929                 bs.nb,
930                 bs.init_cache,
931             )
932
933         bs = self.embedding(bs)
934         bs = self.trunk(bs)
935         bs = self.readout(bs)
936
937         if self.caterpillar_length > 0:
938             bs = BracketedSequence(
939                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
940                 bs.first - self.caterpillar_length,
941                 original_nb,
942                 bs.init_cache,
943             )
944
945         return bs
946
    # ar_mask is a tensor of 0s and 1s with the same shape as input, with
    # 1s where tokens should be generated. The other tokens are kept
    # unchanged.
950
951     def masked_inplace_autoregression(
952         self,
953         input_src,
954         ar_mask_src,
955         forbidden_tokens=None,
956         deterministic_synthesis=False,
957     ):
958         input = input_src.to(self.readout.f.weight.device)
959         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
960         to_generate = (ar_mask.sum(0) > 0).nonzero()
961         if to_generate.min() > 0:
962             self(
963                 BracketedSequence(input, 0, to_generate.min(), True)
964             )  # Needed to initialize the model's cache
965         for s in range(to_generate.min(), to_generate.max() + 1):
966             output = self(BracketedSequence(input, s, 1, s == 0)).x
967             logits = output[:, s]
968             if forbidden_tokens is not None:
969                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
970             if deterministic_synthesis:
971                 t_next = logits.argmax(1)
972             else:
973                 dist = torch.distributions.categorical.Categorical(logits=logits)
974                 t_next = dist.sample()
975             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
976
977         input_src.copy_(input)
978
979     def reset_inner_loss(self):
980         for m in self.modules():
981             if m is not self and hasattr(m, "reset_inner_loss"):
982                 m.reset_inner_loss()
983
984     def get_inner_loss(self):
985         l = torch.tensor([0.0], device=self.readout.f.weight.device)
986         for m in self.modules():
987             if m is not self and hasattr(m, "get_inner_loss"):
988                 l += m.get_inner_loss()
989         return l
990
991     def record_attention(self, v=True):
992         for m in self.modules():
993             if isinstance(m, QKVAttention):
994                 m.record_attention = v
995
996     def retrieve_attention(self):
997         a = []
998         for m in self.modules():
999             if isinstance(m, QKVAttention):
1000                 a.append(m.a)
1001         return a
1002
1003
1004 ######################################################################
1005
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Consistency check of the Caterpillar layer and its cache. The sequence
    # has 21 steps surrounded by caterpillar_length (7) padding steps on each
    # side, mirroring the padding done in MyGPT.forward.
    # - y1 and y2 process the same bracket twice, each time from a fresh cache.
    # - y3a and y3b process it as two consecutive brackets.
    # Both printed maximum differences are expected to be close to zero.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # Stop after the basic check. The speed benchmark below is currently
    # disabled (unreachable); remove this line to run it. sys.exit() is used
    # because the exit() builtin is installed by the site module and is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)

    # Speed benchmark: time forward passes of a MyGPT model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for _ in range(3):
        start_time = time.perf_counter()
        for _ in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1076
1077 ######################################################################