Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
# This implementation is equipped with RNN layers that can replace the
# multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: nb time steps starting at index first.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor together with the bracket of time steps to compute.

    ``first`` is the index of the first time step of the bracket, ``nb`` its
    number of time steps, and ``init_cache`` tells the modules processing it
    to reset their caches. The three are given together or not at all; when
    omitted, the bracket covers the whole time axis and caches are reset.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [value is not None for value in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the time steps of the bracket, i.e. x[:, first:first+nb]."""
        end = self.first + self.nb
        return self.x[:, self.first : end]

    def complete(self):
        """Return True when the bracket spans the full time axis."""
        return (self.first, self.nb) == (0, self.x.size(1))
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket and store its output in a cache.

    The wrapped function (the single module given, or an ``nn.Sequential``
    of all of them) only sees ``bs.slice()``. Its output is written into a
    cache spanning the whole time axis, which is returned as a new
    BracketedSequence with the same bracket.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        start, stop = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # Continuing a sequence: the cache must match its batch/time size
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert stop <= self.cache_y.size(1)

        out = self.f(bs.slice())

        if bs.init_cache:
            # Uninitialized storage; only the processed brackets get filled
            shape = (out.size(0), bs.x.size(1)) + tuple(out.size()[2:])
            self.cache_y = out.new_empty(shape)

        self.cache_y[:, start:stop] = out

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module: x + f(x)."""

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        base = bs.x
        branch = self.f(bs).x
        return BracketedSequence(base + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to a bracketed sequence.

    [Vaswani et al. 2017] PE_{t,2i} = sin(t / L^{2i/D}) and
    PE_{t,2i+1} = cos(t / L^{2i/D}), where L is ``len_max`` and D is the
    size of the last axis of the input.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            nb_steps, dim = x.size(1), x.size(2)
            t = torch.arange(nb_steps, dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(dim, dtype=x.dtype, device=x.device)[None, :]
            k = j % 2
            # j - k is 2i for both channels of a pair; cos(a) == sin(a + pi/2),
            # so odd channels just get a pi/2 phase shift
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / dim)) + math.pi / 2 * k
            )
            self.cache_y = x.new_empty(x.size())

        lo, hi = bs.first, bs.first + bs.nb
        self.cache_y[:, lo:hi] = bs.slice() + self.pe[lo:hi]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
130
131
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` with axis ``dim`` of X as the sequence axis.

    All axes before ``dim`` are folded into a single batch axis; the axes
    between ``dim`` and the last one are kept as they are. A is reshaped to
    X's layout without the last axis, and Y_init (zeros when None) to X's
    layout without axis ``dim``. The result is returned with X's shape.

    NOTE(review): the recurrence itself is implemented in the ``pscan``
    module, presumably Y_t = A_t * Y_{t-1} + X_t; see that module.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    inner = s[dim + 1 : -1]  # axes between the scanned axis and the last one

    A = A.reshape(a, T, *inner)
    X = X.reshape(a, T, *inner, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *inner, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *inner, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
147
148
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    """Experimental backward pass over the time steps of a scan.

    Walks t from X.size(dim) - 1 down to 1 (t = 0 is never visited) under
    torch.no_grad(). It accumulates s_A / s_X and overwrites A[t].grad and
    X[t].grad.

    NOTE(review): this looks unfinished and is not called in the visible
    part of the file. TODO confirm intent before using it:
    - grad_Y[t], A[t] and X[t] index axis 0, not ``dim``;
    - A[t] / X[t] are fresh indexed tensors whose .grad is presumably None,
      so the divisions would fail, and assigning .grad on such a temporary
      does not write into A.grad / X.grad;
    - Y_init and eps are unused.
    """
    with torch.no_grad():
        s_A, s_X = 0, 0
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            A[t].grad = delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta
            X[t].grad = delta
159
160
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` along the second-to-last axis of X.

    X is .../xTxD, A is .../xT and Y_init, if not None, is .../xD; all the
    leading axes are flattened into one batch axis. Y_init defaults to
    zeros. The result has X's shape.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)
    if Y_init is None:
        init = flat_X.new_zeros(flat_X.size(0), D)
    else:
        init = Y_init.reshape(-1, D)

    return pscan.pscan(flat_A, flat_X, init).reshape(shape)
174
175
def nsum_shape(X, Y_init):
    """Cumulative sum along the second-to-last axis with norm clipping.

    X is .../xTxD. Starting from Y_init (reshaped to ?xD, zero when None),
    each step adds X[..., t, :] to the running vector and then rescales it
    so that its L2 norm is at most 1. The running vector after each step is
    returned, with X's shape.
    """
    s = X.size()
    n, T, D = s[:-2].numel(), s[-2], s[-1]
    X = X.reshape(n, T, D)  # explicit n: reshape(-1, 0, D) would be ambiguous

    if T == 0:
        # No time step: nothing to accumulate
        return X.reshape(s)

    Y = 0 if Y_init is None else Y_init.reshape(-1, D)
    result = []

    for k in range(T):
        Y = Y + X[:, k]
        # Only vectors of norm > 1 are shrunk back to the unit sphere
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.stack(result, dim=1).reshape(s)
189
190
191 ##############################
192
193
class DumbRec(nn.Module):
    """Recurrent replacement for attention with ``nb_lines`` value memory lines.

    Writing: each time step computes, per head, a softmax over the memory
    lines of its query ``qw`` against learned line keys ``k_star`` (rotated
    with time, see the "key barrel" in forward). It adds its value to the
    lines with these weights, and the lines are multiplied by
    A = 1 - (weights summed over heads) through a parallel scan.

    Reading: a second query ``qr`` computes a softmax over the (unrotated)
    ``k_star`` and averages the current content of the lines.

    NOTE(review): the exact recurrence is implemented by ``pscan_shape``,
    presumably rec_t = A_t * rec_{t-1} + V_t; see that function.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the squared L2 norm of the mean write attention per line.

        Returns a zero scalar when no training forward pass has happened
        since reset_inner_loss(), instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        x = x_q[:, t0:t1]
        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)

        aw = torch.einsum("nhtd,ltd->nhlt", qw, k_star) / math.sqrt(self.w_qw.size(1))
        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # Must test the flag self.training: self.train is a method, hence
        # always truthy, which made statistics accumulate in eval mode too
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue from the state cached at t0 - 1 when not starting at 0
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        # NOTE(review): the readout uses the unrotated k_star, unlike the
        # write above; confirm this asymmetry is intended
        ar = torch.einsum("nhtd,ld->nhlt", qr, self.k_star) / math.sqrt(
            self.w_qr.size(1)
        )

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum("nhlt,nltd->nthd", ar, self.rec_v[:, :, t0:t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
320
321
322 ##############################
323
324
class KVRec(nn.Module):
    """Recurrent replacement for attention with key and value memory lines.

    Writing: each time step computes, per head, a softmax over the memory
    lines of its query ``qw`` against learned line keys ``k_star`` (rotated
    with time, see the "key barrel" in forward). It adds its key and value
    to the lines with these weights, and the lines are multiplied by
    A = 1 - (weights summed over heads) through a parallel scan.

    Reading: a second query ``qr`` attends over the recurrent keys of the
    lines and averages the corresponding recurrent values.

    NOTE(review): the exact recurrence is implemented by ``pscan_shape``,
    presumably rec_t = A_t * rec_{t-1} + X_t; see that function.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line, used to select where to write
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the squared L2 norm of the mean write attention per line.

        Returns a zero scalar when no training forward pass has happened
        since reset_inner_loss(), instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        x = x_q[:, t0:t1]
        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x, self.w_k)

        aw = torch.einsum("nhtd,ltd->nhlt", qw, k_star) / math.sqrt(self.w_qw.size(1))
        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # Must test the flag self.training: self.train is a method, hence
        # always truthy, which made statistics accumulate in eval mode too
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the state cached at t0 - 1 when not starting at 0
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum("nhlt,nltd->nthd", ar, self.rec_v[:, :, t0:t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
456
457
458 ##############################
459
460
# Returns a view with an additional axis at position win_dim, which
# indexes the offset inside a window of win_size steps sliding along
# axis dim (domain {0...win_size-1}); axis dim itself shrinks by
# win_size-1 values and indexes the window starts.
464
465
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of all windows of ``win_size`` steps along ``dim``.

    The result has one more axis than x, inserted at position ``win_dim``
    of the output. Axis ``dim`` indexes the window start and shrinks to
    x.size(dim) - win_size + 1; the new axis indexes the offset inside the
    window. For instance, with dim=1 and win_dim=2, out[n, t, w] == x[n, t + w].

    No data is copied: the result shares storage (and storage offset) with
    x. ``dim`` may be negative, counted from the end of x's axes; ``win_dim``
    is used as given.
    """
    if dim < 0:
        # Without this, size[dim + 1 :] would be wrong for negative dims
        dim += x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    # Moving inside the window moves along dim, hence the same stride
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
473
474
475 ##############################
476
477
class Caterpillar(nn.Module):
    """Recurrent attention layer whose state is a "caterpillar" of key/value pairs.

    The recurrent state has R = caterpillar_height rows. The state at time t
    is computed by updating the state at time t - L (L = caterpillar_length)
    with the current key/value through per-head, per-row sigmoid gates.
    The readout at time t attends over the R x L key/value pairs of the
    time steps t-L+1 ... t.

    ``args`` may provide gate_dropout_proba, gate_dropout_sync and
    gate_dropout_replace; when it is None (or lacks them), gate dropout is
    disabled.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        # The gate bias below is -log(R - 1), undefined for R < 2
        if caterpillar_height < 2:
            raise ValueError(
                f"caterpillar_height must be at least 2, got {caterpillar_height}"
            )

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.gate_dropout_proba = getattr(args, "gate_dropout_proba", 0.0)
        self.gate_dropout_sync = getattr(args, "gate_dropout_sync", False)
        self.gate_dropout_replace = getattr(args, "gate_dropout_replace", False)

        ######################################################################

        # sigmoid(-log(R - 1)) = 1 / R: gates start at 1/R
        default_bg = -math.log(caterpillar_height - 1)
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        # The L steps before t0 hold the state the recurrence starts from
        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################

        def recurrence(G, V, K):
            # We prepare the arguments for the parallel scan

            A = 1 - G.sum(1)

            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

            # We start from cached values, which matters in inference

            init_rec_V = self.rec_V[:, :, t0 - L : t0]
            init_rec_K = self.rec_K[:, :, t0 - L : t0]

            # Here there is a trick: Since the stack at position t is
            # computed by updating that at position t-L, the parallel
            # scan operates with a period of L. To do so we split the
            # sequence indexing in two axes, the second of size L, and
            # run the parallel scan using the first as the sequence index.

            A = A.unflatten(2, (-1, L))
            gated_V = gated_V.unflatten(2, (-1, L))
            gated_K = gated_K.unflatten(2, (-1, L))

            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)

            return next_V, next_K

        #################################################################

        next_V, next_K = recurrence(G, V, K)

        if self.training and self.gate_dropout_proba > 0.0:
            # G is NxHxRxT where r is the caterpillar's row.

            warnings.warn("gate dropout", RuntimeWarning)

            if self.gate_dropout_sync:
                shape_kill = (N, 1, 1)
            else:
                shape_kill = (N, H, R)

            # Pick a point in each of the NxHxR timeline and set this
            # entry and the following to 1
            kill = (
                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
                == 0
            ).cumsum(dim=3)

            # Keep these mask for only some of the NxHxR
            kill = kill * (
                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
            )

            # The coefficient to keep are the complementary
            mask = 1 - kill

            masked_next_V, masked_next_K = recurrence(G * mask, V, K)

            if self.gate_dropout_replace:
                next_V = next_V.detach()
                next_K = next_K.detach()

            # Forward values stay those of the unmasked recurrence; the
            # masked recurrence only contributes (rescaled) gradients
            next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
                1 - self.gate_dropout_proba
            )
            next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
                1 - self.gate_dropout_proba
            )

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxRxTxLxD where N is the sample index, R
        # the row in the caterpillar, T the time, L the column in the
        # caterpillar, and D the key/value dimension

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
711
712
713 ##############################
714
715
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product attention over bracketed sequences.

    Keys and values of every processed time step are cached, so that a
    causal model can be evaluated bracket by bracket without recomputing the
    past. Non-causal models must be evaluated on complete sequences. When
    ``record_attention`` is set, the attention weights (before dropout) of
    the last call are stored in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        args=None,
    ):
        super().__init__()

        randw = lambda *d: nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            n, t = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(n, self.w_k.size(0), t, self.w_k.size(1))
            self.cache_v = x.new_zeros(n, self.w_v.size(0), t, self.w_v.size(1))
            self.cache_y = x.new_zeros(n, t, self.w_o.size(1))

        current = x[:, t0:t1]
        q = torch.einsum("ntc,hdc->nhtd", current, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", current, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", current, self.w_v)

        # Queries of the bracket against all keys up to its end
        scores = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                steps = torch.arange(x.size(1), device=q.device)
                # True where the key is in the future of the query
                self.cache_attzero = (
                    steps[None, None, :, None] < steps[None, None, None, :]
                )
            scores = scores.masked_fill(
                self.cache_attzero[:, :, t0:t1, :t1], float("-inf")
            )

        weights = scores.softmax(dim=3)

        if self.record_attention:
            self.a = weights

        weights = F.dropout(weights, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", weights, self.cache_v[:, :, :t1]
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
798
799
800 ##############################
801
802
803 class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
        logger=print,
        args=None,
    ):
        """Build the model: embedding, nb_blocks residual blocks, readout.

        The embedding is an nn.Embedding followed by dropout and the
        positional encoding. Each block is a residual LayerNorm + attention
        layer, followed by a residual LayerNorm + MLP (Linear
        dim_model->dim_hidden, ReLU, Linear dim_hidden->dim_model, Dropout).
        The readout is a Linear layer to vocabulary_size logits.

        Args:
            vocabulary_size: number of token values (embedding and readout).
            dim_model: width of the residual stream; must be a multiple of
                nb_heads.
            dim_keys: query/key dimension of the attention layers.
            dim_hidden: hidden width of the MLPs.
            nb_heads: number of heads; each gets dim_model // nb_heads value
                dimensions.
            nb_blocks: number of attention + MLP blocks.
            nb_lines: number of memory lines for "dumbrec" / "kvrec"; for
                "caterpillar", the total size, which must be a multiple of
                caterpillar_height.
            caterpillar_height: number of rows, used only by "caterpillar".
            causal: passed only to the "mha" layer.
            dropout: dropout probability after the embedding, in the MLPs,
                and as attention dropout.
            len_max: length scale of the positional encoding.
            attention_layer: "mha", "dumbrec", "kvrec" or "caterpillar".
            logger, args: forwarded to the attention layers.
        """
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        # The caterpillar arranges nb_lines into caterpillar_height rows of
        # caterpillar_length columns; other layers use -1 as "no caterpillar"
        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        # Heads use dim_model // nb_heads value dims, so that their
        # concatenation matches dim_model
        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        # Factory returning a fresh attention layer of the selected type;
        # called once per block, so blocks do not share parameters
        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                    logger=logger,
                    args=args,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                    args=args,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                    args=args,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                    logger=logger,
                    args=args,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        # NOTE: module creation order fixes both the random initialization
        # sequence and the state_dict keys (trunk.<i>.f...)
        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        # Custom initialization: embeddings ~ N(0, 0.02), LayerNorms reset
        # to identity (weight 1, bias 0)
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        # NOTE(review): reset_inner_loss is not visible here; presumably
        # defined further down in this class
        self.reset_inner_loss()
928
929     def forward(self, bs):
930         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
931
932         # To make the code simpler in the Caterpillar layer, we pad
933         # here. It's unclear if/how much it hurts computationaly by
934         # increasing the sequence length for the other layers
935
936         if self.caterpillar_length > 0:
937             original_nb = bs.nb
938             if bs.nb % self.caterpillar_length > 0:
939                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
940
941             bs = BracketedSequence(
942                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
943                 bs.first + self.caterpillar_length,
944                 bs.nb,
945                 bs.init_cache,
946             )
947
948         bs = self.embedding(bs)
949         bs = self.trunk(bs)
950         bs = self.readout(bs)
951
952         if self.caterpillar_length > 0:
953             bs = BracketedSequence(
954                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
955                 bs.first - self.caterpillar_length,
956                 original_nb,
957                 bs.init_cache,
958             )
959
960         return bs
961
    # ar_mask is a tensor of 0s and 1s with the same shape as input, with
    # 1s where tokens should be generated. The other tokens are kept
    # unchanged.
965
966     def masked_inplace_autoregression(
967         self,
968         input_src,
969         ar_mask_src,
970         forbidden_tokens=None,
971         deterministic_synthesis=False,
972     ):
973         input = input_src.to(self.readout.f.weight.device)
974         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
975         to_generate = (ar_mask.sum(0) > 0).nonzero()
976         if to_generate.min() > 0:
977             self(
978                 BracketedSequence(input, 0, to_generate.min(), True)
979             )  # Needed to initialize the model's cache
980         for s in range(to_generate.min(), to_generate.max() + 1):
981             output = self(BracketedSequence(input, s, 1, s == 0)).x
982             logits = output[:, s]
983             if forbidden_tokens is not None:
984                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
985             if deterministic_synthesis:
986                 t_next = logits.argmax(1)
987             else:
988                 dist = torch.distributions.categorical.Categorical(logits=logits)
989                 t_next = dist.sample()
990             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
991
992         input_src.copy_(input)
993
994     def reset_inner_loss(self):
995         for m in self.modules():
996             if m is not self and hasattr(m, "reset_inner_loss"):
997                 m.reset_inner_loss()
998
999     def get_inner_loss(self):
1000         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1001         for m in self.modules():
1002             if m is not self and hasattr(m, "get_inner_loss"):
1003                 l += m.get_inner_loss()
1004         return l
1005
1006     def record_attention(self, v=True):
1007         for m in self.modules():
1008             if isinstance(m, QKVAttention):
1009                 m.record_attention = v
1010
1011     def retrieve_attention(self):
1012         a = []
1013         for m in self.modules():
1014             if isinstance(m, QKVAttention):
1015                 a.append(m.a)
1016         return a
1017
1018
1019 ######################################################################
1020
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Quick check of the Caterpillar layer. The same input is processed:
    #   - twice, as one bracket [7, 28), and
    #   - once as two consecutive brackets [7, 21) and [21, 28), the second
    #     one with init_cache=False.
    # The printed values are the max absolute differences between these
    # outputs (presumably expected to be close to 0).

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # Use sys.exit(): the builtin exit() is injected by the site module for
    # the interactive interpreter and is not guaranteed to exist in
    # programs (e.g. when Python is started with -S).
    # NOTE(review): the MyGPT timing loop below is unreachable while this
    # early exit is in place.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1091
1092 ######################################################################