760a3c60e086f094d4e86672e040877f94c6a6d2
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
# This implementation also provides recurrent (RNN-like) layers that can
# replace the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with the index
# `first` of the first time step to compute and the number `nb` of time
# steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor together with the bracket of time steps to compute.

    ``first`` is the index (along axis 1) of the first time step of the
    bracket and ``nb`` the number of time steps it covers. ``init_cache``
    tells modules that keep a cache whether they must (re)initialize it.

    Either all of ``first``, ``nb`` and ``init_cache`` are given, or none
    of them is. In the latter case the bracket is the whole sequence and
    the cache is initialized.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        # The part of x covered by the bracket
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        # True when the bracket spans the whole time axis
        if self.first != 0:
            return False
        return self.nb == self.x.size(1)
59     def complete(self):
60         return self.first == 0 and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket and cache the result over time.

    ``f`` (a single module, or several modules chained in an nn.Sequential)
    is applied to ``bs.slice()`` only. The result is written into a cache
    spanning the whole time axis, so that previously computed brackets are
    kept without recomputation. The cache is (re)allocated when
    ``bs.init_cache`` is True. Time steps not computed yet are zero.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled (not uninitialized memory as with the legacy
            # Tensor.new), so positions outside the computed brackets hold
            # deterministic values.
            self.cache_y = y.new_zeros(
                (y.size(0), bs.x.size(1)) + tuple(y.size()[2:])
            )
        else:
            # Subsequent brackets must match the cached batch / time sizes
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert t1 <= self.cache_y.size(1)
            y = self.f(bs.slice())

        self.cache_y[:, t0:t1] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module: x + f(x).

    ``f`` is a single module or several modules chained in an
    nn.Sequential; it must take and return a BracketedSequence.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        # The full tensors are summed, not only the bracket
        residual = self.f(bs).x
        total = bs.x + residual
        return BracketedSequence(total, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a bracketed sequence.

    Uses the encoding of [Vaswani et al. 2017]:

        PE_{t,2i}   = sin(t / (L^{2i/D}))
        PE_{t,2i+1} = cos(t / (L^{2i/D}))

    where L is ``len_max`` and D is ``bs.x.size(2)``. The encoding table and
    an output cache are (re)built when ``bs.init_cache`` is True. Only the
    current bracket is written at each call.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        if bs.init_cache:
            t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
            k = j % 2
            # cos(a) = sin(a + pi/2): odd channels get a pi/2 phase shift
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled (not uninitialized memory as with the legacy
            # Tensor.new), so time steps not computed yet are deterministic
            self.cache_y = x.new_zeros(x.size())

        self.cache_y[:, t0:t1] = bs.slice() + self.pe[t0:t1]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
130
131
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` with axis ``dim`` of X as the scan (time) axis.

    Shapes, with X of shape s:
      - A has the shape of X without its last axis;
      - Y_init has the shape of X without axis ``dim`` (or is None, in
        which case zeros are used);
      - the result has the same shape s as X.

    Axes before ``dim`` are merged into one batch axis; axes strictly
    between ``dim`` and the last one are kept as they are. ``dim`` must not
    designate the last axis of X.

    NOTE(review): the recurrence itself is defined by pscan.pscan
    (presumably Y_t = A_t * Y_{t-1} + X_t) — confirm there.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    inner = s[dim + 1 : -1]

    A = A.reshape(a, T, *inner)
    X = X.reshape(a, T, *inner, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *inner, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *inner, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
147
148
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    """Overwrite the ``.grad`` of the entries of A and X, walking time backward.

    For t from X.size(dim) - 1 down to 1 (entry 0 is never visited), with
    g = A[t].grad (resp. X[t].grad) and a running sum s starting at 0:

        delta = (grad_Y[t] - s) / g;  s += g * delta;  entry t's grad = delta

    Mathematically s equals grad_Y[t] after each step (when g != 0), so the
    new gradient is (grad_Y[t] - grad_Y[t + 1]) / g, with grad_Y[T] = 0.

    NOTE(review): this looks experimental and is not called in the visible
    code. Y_init and eps are unused, and dim only sets the loop length:
    entries are indexed along the first axis. For torch tensors, A[t] / X[t]
    are fresh views whose .grad is None, so the division fails. It only
    runs on containers whose items keep a .grad — confirm intended inputs.
    """
    with torch.no_grad():
        acc_A, acc_X = 0, 0
        for t in reversed(range(1, X.size(dim))):
            step_A = (grad_Y[t] - acc_A) / A[t].grad
            acc_A += A[t].grad * step_A
            A[t].grad = step_A

            step_X = (grad_Y[t] - acc_X) / X[t].grad
            acc_X += X[t].grad * step_X
            X[t].grad = step_X
159
160
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` along the second-to-last axis of X (.../T/D).

    All the leading axes are merged into a single batch axis. A has the
    shape of X without its last axis. Y_init has the shape of X without
    its T axis, or is None for zeros. The result has X's shape.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    A_flat = A.reshape(-1, T)
    X_flat = X.reshape(-1, T, D)

    if Y_init is not None:
        Y0 = Y_init.reshape(-1, D)
    else:
        Y0 = X_flat.new_zeros(X_flat.size(0), D)

    return pscan.pscan(A_flat, X_flat, Y0).reshape(shape)
174
175
def nsum_shape(X, Y_init):
    """Running sum over the time axis of X (.../T/D), renormalized each step.

    For every leading index computes, for k = 0..T-1,

        Y_k = (Y_{k-1} + X_k) / max(1, ||Y_{k-1} + X_k||)

    with Y_{-1} = Y_init (of X's shape without the T axis), or 0 when
    Y_init is None. Returns the Y_k stacked along the time axis, with the
    same shape as X. An input with no time step returns an empty tensor of
    that shape.
    """
    s = X.size()

    if s[-2] == 0:
        # reshape(-1, 0, D) is ambiguous and torch.stack([]) fails
        return X.new_zeros(s)

    X = X.reshape(-1, s[-2], s[-1])  # N x T x D

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        # Keep the running state inside the unit ball
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.stack(result, dim=1).reshape(s)
189
190
191 ##############################
192
193
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention with learned line keys.

    The recurrent state is made of ``nb_lines`` value lines per sample.

    Write: at each time step, every head computes a write query and a
    softmax over the lines against the learned keys ``k_star``. At time t,
    line l uses key ``k_star[(l + t) % nb_lines]`` (a rotating key barrel).
    The attention-weighted values are fed to a parallel scan (pscan_shape)
    with coefficient ``1 - (write attention summed over heads)``.

    Read: a read query attends over the lines against the non-rotated
    ``k_star``.

    ``len_max``, ``logger`` and ``args`` are accepted for interface
    compatibility with the other attention layers and are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line
        self.k_star = randw(nb_lines, dim_qk)

        # Write queries, read queries, values and output projection
        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Sum over lines of the squared mean write attention (l2 regularizer).

        The write attention of each head sums to 1 over the lines, so this
        is minimal when all lines are used equally. Returns a zero scalar
        if nothing was accumulated since reset_inner_loss(), e.g. after
        evaluation-only passes, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # rec_v is N x nb_lines x T x dim_v: the line contents over time
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, each head's write sums to 1 over lines

        # self.training is the mode flag; self.train is the bound method
        # nn.Module.train and is always truthy.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue from the state just before the bracket, if there is one
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
320
321
322 ##############################
323
324
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention with key/value lines.

    Like a key-value memory with ``nb_lines`` lines per sample.

    Write: at each time step, every head computes a write query and a
    softmax over the lines against the learned keys ``k_star``. At time t,
    line l uses key ``k_star[(l + t) % nb_lines]`` (a rotating key barrel).
    The attention-weighted keys and values are fed to parallel scans
    (pscan_shape) with coefficient ``1 - (write attention summed over heads)``.

    Read: a read query attends over the lines against the stored recurrent
    keys ``rec_k``.

    ``len_max``, ``logger`` and ``args`` are accepted for interface
    compatibility with the other attention layers and are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned addressing key per memory line
        self.k_star = randw(nb_lines, dim_qk)

        # Write queries, read queries, keys, values and output projection
        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Sum over lines of the squared mean write attention (l2 regularizer).

        The write attention of each head sums to 1 over the lines, so this
        is minimal when all lines are used equally. Returns a zero scalar
        if nothing was accumulated since reset_inner_loss(), e.g. after
        evaluation-only passes, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # rec_v / rec_k are N x nb_lines x T x dim: line contents over time
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, each head's write sums to 1 over lines

        # self.training is the mode flag; self.train is the bound method
        # nn.Module.train and is always truthy.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the state just before the bracket, if there is one
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
456
457
458 ##############################
459
460
# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, on a domain {0...win_size-1}, while
# dim is restricted to a domain reduced by win_size-1 values.
464
465
466 def moving_window(x, dim, win_dim, win_size):
467     size, stride = x.size(), x.stride()
468     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
469     size = size[:win_dim] + (win_size,) + size[win_dim:]
470     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
471
472     return x.as_strided(size=size, stride=stride)
473
474
475 ##############################
476
477
class Caterpillar(nn.Module):
    """Recurrent attention layer whose state is an R x L "caterpillar".

    R = caterpillar_height rows of key/value pairs and L =
    caterpillar_length columns. The state at time t is obtained by
    updating the state at time t - L with the current key/value, through
    R gates per head. The readout at time t attends over the R x L pairs
    of the states at times t-L+1, ..., t.

    forward requires ``bs.first >= L`` and ``bs.nb`` a multiple of L.
    The default gate bias is -log(R - 1), so R must be at least 2.

    ``args`` may provide ``gate_dropout_proba`` and ``gate_dropout_sync``.
    When it is None or lacks them, gate dropout is disabled.
    ``len_max`` and ``logger`` are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # args defaults to None, so do not assume the attributes exist
        self.gate_dropout_proba = getattr(args, "gate_dropout_proba", 0.0)
        self.gate_dropout_sync = getattr(args, "gate_dropout_sync", False)

        ######################################################################

        # sigmoid(-log(R - 1)) = 1 / R: by default each row gate is about 1/R
        default_bg = -math.log(caterpillar_height - 1)
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learned content of the L time steps preceding the first bracket
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No inner loss for this layer
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert t0 >= L and (t1 - t0) % L == 0, (
            f"bs.first ({t0}) should be at least caterpillar_length ({L}), "
            f"and bs.nb ({t1 - t0}) should be a multiple of caterpillar_length"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################

        def recurrence(G, V, K):
            # We prepare the arguments for the parallel scan

            A = 1 - G.sum(1)

            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

            # We start from cached values, which matters in inference

            init_rec_V = self.rec_V[:, :, t0 - L : t0]
            init_rec_K = self.rec_K[:, :, t0 - L : t0]

            # Here there is a trick: Since the stack at position t is
            # computed by updating that at position t-L, the parallel
            # scan operates with a period of L. To do so we split the
            # sequence indexing in two axes, the second of size L, and
            # run the parallel scan using the first as the sequence index.

            A = A.unflatten(2, (-1, L))
            gated_V = gated_V.unflatten(2, (-1, L))
            gated_K = gated_K.unflatten(2, (-1, L))

            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

            next_V = next_V.flatten(2, 3)
            next_K = next_K.flatten(2, 3)

            return next_V, next_K

        #################################################################

        next_V, next_K = recurrence(G, V, K)

        if self.training and self.gate_dropout_proba > 0.0:
            # G is NxHxRxT where r is the caterpillar's row.

            warnings.warn("gate dropout", RuntimeWarning)

            if self.gate_dropout_sync:
                shape_kill = (N, 1, 1)
            else:
                shape_kill = (N, H, R)

            # Pick a point in each of the NxHxR timeline and set this
            # entry and the following to 1
            kill = (
                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
                == 0
            ).cumsum(dim=3)

            # Keep these mask for only some of the NxHxR
            kill = kill * (
                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
            )

            # The coefficient to keep are the complementary
            mask = 1 - kill

            masked_next_V, masked_next_K = recurrence(G * mask, V, K)

            # Forward value of the unmasked recurrence, gradient of the
            # masked one rescaled by 1 / (1 - p)
            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
                1 - self.gate_dropout_proba
            )
            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
                1 - self.gate_dropout_proba
            )

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxRxTxLxD where N is the sample index, R the
        # row in the caterpillar, T the time, and L the column in the
        # caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
709
710
711 ##############################
712
713
class QKVAttention(nn.Module):
    """Multi-head (optionally causal) attention on bracketed sequences.

    Keys and values of each processed bracket are stored in caches that
    span the whole time axis. Each bracket computes only its own queries,
    keys and values, and attends to all cached keys up to its end.
    Non-causal models can only process complete sequences.
    ``logger`` and ``args`` are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        args=None,
    ):
        super().__init__()

        def weight(*shape):
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = weight(nb_heads, dim_qk, dim_model)
        self.w_k = weight(nb_heads, dim_qk, dim_model)
        self.w_v = weight(nb_heads, dim_v, dim_model)
        self.w_o = weight(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        start, stop = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            batch, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(batch, self.w_k.size(0), length, self.w_k.size(1))
            self.cache_v = x.new_zeros(batch, self.w_v.size(0), length, self.w_v.size(1))
            self.cache_y = x.new_zeros(batch, length, self.w_o.size(1))

        x_new = x[:, start:stop]

        def project(w):
            return torch.einsum("ntc,hdc->nhtd", x_new, w)

        q = project(self.w_q)
        self.cache_k[:, :, start:stop] = project(self.w_k)
        self.cache_v[:, :, start:stop] = project(self.w_v)

        # Attend to every cached key up to the end of the bracket
        keys = self.cache_k[:, :, :stop]
        a = torch.einsum("nhtd,nhsd->nhts", q, keys) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                # True where the key position s is after the query position t
                positions = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            future = self.cache_attzero[:, :, start:stop, :stop]
            a = a.masked_fill(future, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        values = self.cache_v[:, :, :stop]
        y = torch.einsum("nhts,nhsd->nthd", a, values).flatten(2)

        self.cache_y[:, start:stop] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
796
797
798 ##############################
799
800
801 class MyGPT(nn.Module):
802     def __init__(
803         self,
804         vocabulary_size,
805         dim_model,
806         dim_keys,
807         dim_hidden,
808         nb_heads,
809         nb_blocks,
810         nb_lines=None,
811         caterpillar_height=None,
812         causal=False,
813         dropout=0.0,
814         len_max=1e5,
815         attention_layer="kvrec",
816         logger=print,
817         args=None,
818     ):
819         super().__init__()
820
821         assert attention_layer in {
822             "mha",
823             "dumbrec",
824             "kvrec",
825             "caterpillar",
826         }, f"Unknown attention operator {attention_layer}."
827
828         if attention_layer == "caterpillar":
829             assert nb_lines % caterpillar_height == 0
830             self.caterpillar_length = nb_lines // caterpillar_height
831             self.caterpillar_height = caterpillar_height
832         else:
833             self.caterpillar_length = -1
834             self.caterpillar_height = -1
835
836         assert dim_model % nb_heads == 0
837
838         self.embedding = nn.Sequential(
839             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
840             AddPositionalEncoding(len_max),
841         )
842
843         trunk_blocks = []
844
845         def attlayer():
846             if attention_layer == "mha":
847                 return QKVAttention(
848                     dim_model=dim_model,
849                     dim_qk=dim_keys,
850                     dim_v=dim_model // nb_heads,
851                     nb_heads=nb_heads,
852                     causal=causal,
853                     attention_dropout=dropout,
854                     logger=logger,
855                     args=args,
856                 )
857             elif attention_layer == "dumbrec":
858                 return DumbRec(
859                     dim_model=dim_model,
860                     dim_qk=dim_keys,
861                     dim_v=dim_model // nb_heads,
862                     nb_heads=nb_heads,
863                     nb_lines=nb_lines,
864                     attention_dropout=dropout,
865                     logger=logger,
866                     args=args,
867                 )
868             elif attention_layer == "kvrec":
869                 return KVRec(
870                     dim_model=dim_model,
871                     dim_qk=dim_keys,
872                     dim_v=dim_model // nb_heads,
873                     nb_heads=nb_heads,
874                     nb_lines=nb_lines,
875                     attention_dropout=dropout,
876                     logger=logger,
877                     args=args,
878                 )
879             elif attention_layer == "caterpillar":
880                 return Caterpillar(
881                     dim_model=dim_model,
882                     dim_qk=dim_keys,
883                     dim_v=dim_model // nb_heads,
884                     nb_heads=nb_heads,
885                     caterpillar_length=self.caterpillar_length,
886                     caterpillar_height=self.caterpillar_height,
887                     attention_dropout=dropout,
888                     logger=logger,
889                     args=args,
890                 )
891             else:
892                 raise ValueError(f"Unknown attention type {attention_layer}.")
893
894         for b in range(nb_blocks):
895             trunk_blocks += [
896                 WithResidual(
897                     CacheWrapper(nn.LayerNorm((dim_model,))),
898                     attlayer(),
899                 ),
900                 WithResidual(
901                     CacheWrapper(
902                         nn.LayerNorm((dim_model,)),
903                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
904                         nn.ReLU(),
905                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
906                         nn.Dropout(dropout),
907                     ),
908                 ),
909             ]
910
911         self.trunk = nn.Sequential(*trunk_blocks)
912
913         self.readout = CacheWrapper(
914             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
915         )
916
917         with torch.no_grad():
918             for m in self.modules():
919                 if isinstance(m, nn.Embedding):
920                     m.weight.normal_(mean=0, std=2e-2)
921                 elif isinstance(m, nn.LayerNorm):
922                     m.bias.zero_()
923                     m.weight.fill_(1.0)
924
925         self.reset_inner_loss()
926
927     def forward(self, bs):
928         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
929
930         # To make the code simpler in the Caterpillar layer, we pad
931         # here. It's unclear if/how much it hurts computationaly by
932         # increasing the sequence length for the other layers
933
934         if self.caterpillar_length > 0:
935             original_nb = bs.nb
936             if bs.nb % self.caterpillar_length > 0:
937                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
938
939             bs = BracketedSequence(
940                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
941                 bs.first + self.caterpillar_length,
942                 bs.nb,
943                 bs.init_cache,
944             )
945
946         bs = self.embedding(bs)
947         bs = self.trunk(bs)
948         bs = self.readout(bs)
949
950         if self.caterpillar_length > 0:
951             bs = BracketedSequence(
952                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
953                 bs.first - self.caterpillar_length,
954                 original_nb,
955                 bs.init_cache,
956             )
957
958         return bs
959
    # ar_mask is a tensor of 0s and 1s with the same shape as input, with
    # 1s where tokens should be generated. The other tokens are kept
    # unchanged.
963
964     def masked_inplace_autoregression(
965         self,
966         input_src,
967         ar_mask_src,
968         forbidden_tokens=None,
969         deterministic_synthesis=False,
970     ):
971         input = input_src.to(self.readout.f.weight.device)
972         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
973         to_generate = (ar_mask.sum(0) > 0).nonzero()
974         if to_generate.min() > 0:
975             self(
976                 BracketedSequence(input, 0, to_generate.min(), True)
977             )  # Needed to initialize the model's cache
978         for s in range(to_generate.min(), to_generate.max() + 1):
979             output = self(BracketedSequence(input, s, 1, s == 0)).x
980             logits = output[:, s]
981             if forbidden_tokens is not None:
982                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
983             if deterministic_synthesis:
984                 t_next = logits.argmax(1)
985             else:
986                 dist = torch.distributions.categorical.Categorical(logits=logits)
987                 t_next = dist.sample()
988             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
989
990         input_src.copy_(input)
991
992     def reset_inner_loss(self):
993         for m in self.modules():
994             if m is not self and hasattr(m, "reset_inner_loss"):
995                 m.reset_inner_loss()
996
997     def get_inner_loss(self):
998         l = torch.tensor([0.0], device=self.readout.f.weight.device)
999         for m in self.modules():
1000             if m is not self and hasattr(m, "get_inner_loss"):
1001                 l += m.get_inner_loss()
1002         return l
1003
1004     def record_attention(self, v=True):
1005         for m in self.modules():
1006             if isinstance(m, QKVAttention):
1007                 m.record_attention = v
1008
1009     def retrieve_attention(self):
1010         a = []
1011         for m in self.modules():
1012             if isinstance(m, QKVAttention):
1013                 a.append(m.a)
1014         return a
1015
1016
1017 ######################################################################
1018
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Caterpillar cache-consistency check. The input has 21 time steps
    # plus caterpillar_length (7) steps of padding on each side, the same
    # layout MyGPT.forward builds. y1 and y2 process the same bracket
    # twice from a fresh cache. y3a/y3b process it as two successive
    # brackets, the second one (init_cache=False) continuing from the
    # cache of the first. Both printed differences are expected to be
    # close to zero.

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The benchmark below is currently disabled: stop after the basic
    # check. This uses sys.exit because the exit() builtin is injected by
    # the site module for interactive use only, and is absent under
    # `python -S`.
    sys.exit(0)

    # Throughput benchmark of a full MyGPT (unreachable while the
    # sys.exit above is in place).
    # NOTE(review): the MyGPT arguments below (e.g. causal=True) have not
    # been checked against MyGPT.__init__; verify before re-enabling.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1089
1090 ######################################################################