[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is also equipped with RNN-like layers that can replace the MHA blocks.
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor together with a first time step
29 # and a number nb of time steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True
43
44
45 class BracketedSequence:
46     def __init__(self, x, first=None, nb=None, init_cache=None):
47         self.x = x
48         assert (first is None and nb is None and init_cache is None) or (
49             first is not None and nb is not None and init_cache is not None
50         )
51
52         self.first = 0 if first is None else first
53         self.nb = x.size(1) if nb is None else nb
54         self.init_cache = True if init_cache is None else init_cache
55
56     def slice(self):
57         return self.x[:, self.first : self.first + self.nb]
58
59     def complete(self):
60         return self.first == 0 and self.nb == self.x.size(1)
61
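# For illustration only (not part of the original code): a typical causal
# generation loop drives such modules with a first bracket that covers the
# prompt, then brackets of length 1, e.g.
#
#   y = model(BracketedSequence(x, 0, prompt_len, True))   # fills the caches
#   for t in range(prompt_len, x.size(1)):
#       y = model(BracketedSequence(x, t, 1, False))
#       x[:, t] = y.x[:, t].argmax(dim=-1)
#
# where model, x and prompt_len are hypothetical stand-ins.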
62
63 ######################################################################
64
65
66 class CacheWrapper(nn.Module):
67     def __init__(self, *f):
68         super().__init__()
69         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
70
71     def forward(self, bs):
72         if bs.init_cache:
73             y = self.f(bs.slice())
74             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
75             self.cache_y[:, bs.first : bs.first + bs.nb] = y
76         else:
77             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
78             assert bs.first + bs.nb <= self.cache_y.size(1)
79             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
80
81         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
87 class WithResidual(nn.Module):
88     def __init__(self, *f):
89         super().__init__()
90         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
91
92     def forward(self, bs):
93         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
99 class AddPositionalEncoding(nn.Module):
100     def __init__(self, len_max):
101         super().__init__()
102         self.len_max = len_max
103
104     # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
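    # Both cases are obtained below with a single sin, by shifting the phase
    # by pi/2 for the odd dimensions, since sin(x + pi/2) = cos(x).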
105
106     def forward(self, bs):
107         if bs.init_cache:
108             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
109                 :, None
110             ]
111             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
112                 None, :
113             ]
114             k = j % 2
115             self.pe = torch.sin(
116                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
117             )
118             self.cache_y = bs.x.new(bs.x.size())
119
120         self.cache_y[:, bs.first : bs.first + bs.nb] = (
121             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
122         )
123
124         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
130
131
132 def pscan_dim(A, X, Y_init, dim=-2):
133     s = X.size()
134     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
135
136     A = A.reshape(a, T, *s[dim + 1 : -1])
137     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
138
139     if Y_init is None:
140         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
141     else:
142         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
143
144     Y = pscan.pscan(A, X, Y_init).reshape(s)
145
146     return Y
147
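# For reference (this is an assumption about the pscan module, not code taken
# from it): pscan.pscan(A, X, Y_init), with A of shape NxT, X of shape NxTxD
# and Y_init of shape NxD, computes in parallel over t the linear recurrence
# that a naive loop would write as
#
#   y = Y_init
#   for t in range(X.size(1)):
#       y = A[:, t, None] * y + X[:, t]   # Y[:, t] = y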
148
149 def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
150     with torch.no_grad():
151         s_A, s_X = 0, 0
152         for t in range(X.size(dim) - 1, 0, -1):
153             delta = (grad_Y[t] - s_A) / A[t].grad
154             s_A += A[t].grad * delta
155             A[t].grad = delta
156             delta = (grad_Y[t] - s_X) / X[t].grad
157             s_X += X[t].grad * delta
158             X[t].grad = delta
159
160
161 def pscan_shape(A, X, Y_init):
162     s = X.size()
163     A = A.reshape(-1, s[-2])
164     X = X.reshape(-1, s[-2], s[-1])
165
166     if Y_init is None:
167         Y_init = X.new_zeros(X.size(0), s[-1])
168     else:
169         Y_init = Y_init.reshape(-1, s[-1])
170
171     Y = pscan.pscan(A, X, Y_init).reshape(s)
172
173     return Y
174
175
176 def nsum_shape(X, Y_init):
177     s = X.size()
178     X = X.reshape(-1, s[-2], s[-1])  # ntd
179
180     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
181     result = []
182
183     for k in range(X.size(1)):
184         Y = Y + X[:, k]
185         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
186         result.append(Y)
187
188     return torch.cat(result, dim=1).reshape(s)
189
190
191 ##############################
192
193
194 class DumbRec(nn.Module):
195     def __init__(
196         self,
197         dim_model,
198         dim_qk,
199         dim_v,
200         nb_heads,
201         nb_lines,
202         attention_dropout=0.0,
203         len_max=1e5,
204         logger=print,
205         args=None,
206     ):
207         super().__init__()
208
209         def randw(*d):
210             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
211
212         self.nb_lines = nb_lines
213         self.attention_dropout = attention_dropout
214
215         self.k_star = randw(nb_lines, dim_qk)
216
217         self.w_qw = randw(nb_heads, dim_qk, dim_model)
218         self.w_qr = randw(nb_heads, dim_qk, dim_model)
219         # self.w_k = randw(nb_heads, dim_qk, dim_model)
220         self.w_v = randw(nb_heads, dim_v, dim_model)
221         self.w_o = randw(dim_v * nb_heads, dim_model)
222
223     def reset_inner_loss(self):
224         self.acc_attention = 0
225         self.acc_nb = 0
226
227     def get_inner_loss(self):
228         warnings.warn("l2 regularization", RuntimeWarning)
229         return (self.acc_attention / self.acc_nb).pow(2).sum()
230         # return torch.tensor([0], device=self.w_qw.device)
231
232     def forward(self, bs):
233         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
234
235         if bs.init_cache:
236             self.rec_v = x_q.new_zeros(
237                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
238             )
239             # self.rec_k = x_q.new_zeros(
240             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
241             # )
242             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
243
244         ######################################################################
245         # Prepare the keys
246
247         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
248
249         warnings.warn("rotating key barrel", RuntimeWarning)
250         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
251         t_barrel = torch.arange(t0, t1, device=k_star.device)
252         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
253         l_barrel = (
254             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
255         ) % k_star.size(0)
256         k_star = k_star[l_barrel, t_barrel]
257
258         ######################################################################
259         # Compute the recurrent state
260
261         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
262
263         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
264         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
265
266         aw = torch.einsum(
267             "nhtd,ltd->nhlt",
268             qw,
269             k_star,
270         ) / math.sqrt(self.w_qw.size(1))
271
272         aw = aw.softmax(dim=2)  # nhlt
273
274         if self.training:
275             self.acc_attention += aw.sum(dim=(0, 1, 3))
276             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
277
278         aw = F.dropout(aw, self.attention_dropout, self.training)
279
280         A = 1 - aw.sum(dim=1)  # nlt
281
282         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
283         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
284
285         if t0 == 0:
286             V0 = None
287             # K0 = None
288         else:
289             V0 = self.rec_v[:, :, t0 - 1]
290             # K0 = self.rec_k[:, :, t0 - 1]
291
292         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
293         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
294
295         ######################################################################
296         # compute the readout
297
298         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
299
300         ar = torch.einsum(
301             "nhtd,ld->nhlt",
302             qr,
303             # self.rec_k[:, :, t0:t1],
304             self.k_star,
305         ) / math.sqrt(self.w_qr.size(1))
306
307         ar = ar.softmax(dim=2)  # nhlt
308
309         ar = F.dropout(ar, self.attention_dropout, self.training)
310
311         y = torch.einsum(
312             "nhlt,nltd->nthd",
313             ar,
314             self.rec_v[:, :, t0:t1],
315         ).flatten(2)
316
317         self.cache_y[:, t0:t1] = y @ self.w_o
318
319         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
320
321
322 ##############################
323
324
325 class KVRec(nn.Module):
326     def __init__(
327         self,
328         dim_model,
329         dim_qk,
330         dim_v,
331         nb_heads,
332         nb_lines,
333         attention_dropout=0.0,
334         len_max=1e5,
335         logger=print,
336         args=None,
337     ):
338         super().__init__()
339
340         def randw(*d):
341             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
342
343         self.nb_lines = nb_lines
344         self.attention_dropout = attention_dropout
345
346         self.k_star = randw(nb_lines, dim_qk)
347
348         self.w_qw = randw(nb_heads, dim_qk, dim_model)
349         self.w_qr = randw(nb_heads, dim_qk, dim_model)
350         self.w_k = randw(nb_heads, dim_qk, dim_model)
351         self.w_v = randw(nb_heads, dim_v, dim_model)
352         self.w_o = randw(dim_v * nb_heads, dim_model)
353
354     def reset_inner_loss(self):
355         self.acc_attention = 0
356         self.acc_nb = 0
357
358     def get_inner_loss(self):
359         warnings.warn("l2 regularization", RuntimeWarning)
360         return (self.acc_attention / self.acc_nb).pow(2).sum()
361         # return torch.tensor([0], device=self.w_qw.device)
362         # warnings.warn("side regularization", RuntimeWarning)
363         # return (
364         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
365         # )
366         # return torch.tensor([0], device=self.w_qw.device)
367
368     def forward(self, bs):
369         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
370
371         if bs.init_cache:
372             self.rec_v = x_q.new_zeros(
373                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
374             )
375             self.rec_k = x_q.new_zeros(
376                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
377             )
378             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
379
380         ######################################################################
381         # Prepare the keys
382
383         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
384
385         warnings.warn("rotating key barrel", RuntimeWarning)
386         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
387         t_barrel = torch.arange(t0, t1, device=k_star.device)
388         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
389         l_barrel = (
390             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
391         ) % k_star.size(0)
392         k_star = k_star[l_barrel, t_barrel]
393
394         ######################################################################
395         # Compute the recurrent state
396
397         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
398
399         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
400         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
401
402         aw = torch.einsum(
403             "nhtd,ltd->nhlt",
404             qw,
405             k_star,
406         ) / math.sqrt(self.w_qw.size(1))
407
408         aw = aw.softmax(dim=2)  # nhlt
409
410         if self.training:
411             # We want all the memory lines to be used similarly
412             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
413             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
414
415         aw = F.dropout(aw, self.attention_dropout, self.training)
416
417         A = 1 - aw.sum(dim=1)  # nlt
418
419         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
420         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
421
422         if t0 == 0:
423             V0 = None
424             K0 = None
425         else:
426             V0 = self.rec_v[:, :, t0 - 1]
427             K0 = self.rec_k[:, :, t0 - 1]
428
429         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
430         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
431
432         ######################################################################
433         # compute the readout
434
435         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
436
437         ar = torch.einsum(
438             "nhtd,nltd->nhlt",
439             qr,
440             self.rec_k[:, :, t0:t1],
441         ) / math.sqrt(self.w_qr.size(1))
442
443         ar = ar.softmax(dim=2)  # nhlt
444
445         ar = F.dropout(ar, self.attention_dropout, self.training)
446
447         y = torch.einsum(
448             "nhlt,nltd->nthd",
449             ar,
450             self.rec_v[:, :, t0:t1],
451         ).flatten(2)
452
453         self.cache_y[:, t0:t1] = y @ self.w_o
454
455         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
456
457
458 ##############################
459
460
461 # Returns a view of x with an additional index at rank win_dim that moves
462 # along the same dimension as dim, over the domain {0...win_size-1}, while
463 # dim is restricted to a domain reduced by win_size-1 values.
464
465
466 def moving_window(x, dim, win_dim, win_size):
467     size, stride = x.size(), x.stride()
468     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
469     size = size[:win_dim] + (win_size,) + size[win_dim:]
470     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
471
472     return x.as_strided(size=size, stride=stride)
473
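# A small sanity check of the windowing (illustrative, not in the original):
#
#   x = torch.arange(5)[None, :]                       # 1x5
#   moving_window(x, dim=1, win_dim=2, win_size=3)
#   # -> tensor([[[0, 1, 2], [1, 2, 3], [2, 3, 4]]])   # 1x3x3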
474
475 ##############################
476
477
478 class Caterpillar(nn.Module):
479     def __init__(
480         self,
481         dim_model,
482         dim_qk,
483         dim_v,
484         nb_heads,
485         caterpillar_length,
486         caterpillar_height,
487         attention_dropout=0.0,
488         len_max=1e5,
489         logger=print,
490         args=None,
491     ):
492         super().__init__()
493
494         warnings.warn("Caterpillar", RuntimeWarning)
495
496         def randw(*d, amplitude=None):
497             if amplitude is None:
498                 amplitude = 1 / math.sqrt(d[-1])
499             return nn.Parameter(amplitude * torch.randn(*d))
500
501         self.caterpillar_length = caterpillar_length
502         self.caterpillar_height = caterpillar_height
503         self.attention_dropout = attention_dropout
504
505         self.gate_dropout_proba = args.gate_dropout_proba if args is not None else 0.0
506         self.gate_dropout_sync = args.gate_dropout_sync if args is not None else False
507
508         ######################################################################
509
510         default_bg = -math.log(caterpillar_height - 1)
511         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
512         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
513
514         self.w_K = randw(nb_heads, dim_qk, dim_model)
515         self.w_V = randw(nb_heads, dim_v, dim_model)
516         self.w_Q = randw(nb_heads, dim_qk, dim_model)
517         self.w_O = randw(dim_v * nb_heads, dim_model)
518
519         self.init_K_rec = randw(
520             caterpillar_height,
521             caterpillar_length,
522             dim_qk,
523         )
524         self.init_V_rec = randw(
525             caterpillar_height,
526             caterpillar_length,
527             dim_v,
528         )
529
530     def reset_inner_loss(self):
531         self.acc_attention = 0
532         self.acc_nb = 0
533
534     def get_inner_loss(self):
535         # warnings.warn("l2 regularization", RuntimeWarning)
536         # return (self.acc_attention / self.acc_nb).pow(2).sum()
537         return torch.tensor([0], device=self.w_Q.device)
538
539     def forward(self, bs):
540         # Name the dimensions explicitly, to make the code below easier to follow
541
542         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
543
544         N = bs.x.size(0)
545         T = bs.x.size(1)
546         H = self.w_V.size(0)
547         DV = self.w_V.size(1)
548         DK = self.w_K.size(1)
549         DM = self.w_O.size(1)
550         R = self.caterpillar_height
551         L = self.caterpillar_length
552
553         assert (
554             t0 >= L and (t1 - t0) % L == 0
555         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
556
557         # We cache values to deal efficiently with auto-regression
558
559         if bs.init_cache:
560             self.rec_V = X.new_zeros(N, R, T, DV)
561             self.rec_K = X.new_zeros(N, R, T, DK)
562             # We start the recurrent sequences with optimizable
563             # initial values. No idea if it helps.
564             self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
565             self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
566
567             self.cache_Y = X.new_zeros(N, T, DM)
568
569         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
570         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
571
572         ######################################################################
573         # Compute the recurrent state
574
575         # This is the Gating sequence that modulates the storing of
576         # the new key and value in the R pairs of the current
577         # stack. There are R independent gating values, which means
578         # that the current K/V may be stored in multiple pairs of the
579         # recurrent state, or not at all.
580
581         G = (
582             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
583         ).sigmoid()
584
585         # Renormalize the gating so that the total gating of a row, summed
586         # over the heads, cannot exceed 1 (otherwise 1 - sum would be negative)
587
588         G = G / G.sum(1, keepdim=True).clamp(min=1)
589
590         ######################################################################
591
592         def recurrence(G, V, K):
593             # We prepare the arguments for the parallel scan
594
595             A = 1 - G.sum(1)
596
597             gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
598             gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
599
600             # We start from cached values, which matters in inference
601
602             init_rec_V = self.rec_V[:, :, t0 - L : t0]
603             init_rec_K = self.rec_K[:, :, t0 - L : t0]
604
605             # Here there is a trick: Since the stack at position t is
606             # computed by updating that at position t-L, the parallel
607             # scan operates with a period of L. To do so we split the
608             # sequence indexing in two axes, the second of size L, and
609             # run the parallel scan using the first as the sequence index.
610
611             A = A.unflatten(2, (-1, L))
612             gated_V = gated_V.unflatten(2, (-1, L))
613             gated_K = gated_K.unflatten(2, (-1, L))
614
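            # For instance (illustration only), with L = 4 and t1 - t0 = 8,
            # A goes from NxRx8 to NxRx2x4, and the scan over dim=2 realizes
            # Y[t] = A[t] * Y[t-L] + X[t], one independent recurrence per
            # column of the caterpillar.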
615             next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
616             next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
617
618             next_V = next_V.flatten(2, 3)
619             next_K = next_K.flatten(2, 3)
620
621             return next_V, next_K
622
623         #################################################################
624
625         next_V, next_K = recurrence(G, V, K)
626
627         if self.training and self.gate_dropout_proba > 0.0:
628             # G is NxHxRxT, where R indexes the caterpillar's rows.
629
630             warnings.warn("gate dropout", RuntimeWarning)
631
632             # Pick a point in each of the NxHxR timelines and set this
633             # entry and all the following ones to 1
634             kill = (
635                 torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
636             ).cumsum(dim=3)
637
638             # Keep these masks for only some of the NxHxR timelines
639             kill = kill * (
640                 torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
641             )
642
643             # The coefficients to keep are the complement of the kill mask
644             mask = 1 - kill
645
646             masked_next_V, masked_next_K = recurrence(G * mask, V, K)
647
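            # Straight-through combination (explanatory note): the values used
            # in the forward pass are the un-dropped next_V / next_K, while the
            # gradient flows only through the masked recurrence, rescaled by
            # 1 / (1 - gate_dropout_proba).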
648             next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
649                 1 - self.gate_dropout_proba
650             )
651             next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
652                 1 - self.gate_dropout_proba
653             )
654
655         self.rec_V[:, :, t0:t1] = next_V
656         self.rec_K[:, :, t0:t1] = next_K
657
658         ######################################################################
659         # compute the readout
660
661         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
662
663         # We build tensors NxHxTxRxL where N is the sample index, H
664         # the head, T the time, R the row in the caterpillar, and L
665         # the column in the caterpillar
666
667         windowed_V = moving_window(
668             self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
669         )
670
671         windowed_K = moving_window(
672             self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
673         )
674
675         # We have an attention score for each of the RxL values
676
677         ar = torch.einsum(
678             "nhtd,nrtld->nhtrl",
679             Q,
680             windowed_K,
681         ) / math.sqrt(DK)
682
683         # softmax can operate only on one dimension, hence the
684         # flattening
685
686         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
687
688         ar = F.dropout(ar, self.attention_dropout, self.training)
689
690         # Compute the output for each head, flatten to concatenate
691
692         Y = torch.einsum(
693             "nhtfl,nftld->nthd",
694             ar,
695             windowed_V,
696         ).flatten(2)
697
698         # Compute the final output
699
700         self.cache_Y[:, t0:t1] = Y @ self.w_O
701
702         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
703
704
705 ##############################
706
707
708 class QKVAttention(nn.Module):
709     def __init__(
710         self,
711         dim_model,
712         dim_qk,
713         dim_v,
714         nb_heads=1,
715         causal=False,
716         attention_dropout=0.0,
717         logger=print,
718         args=None,
719     ):
720         super().__init__()
721
722         def randw(*d):
723             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
724
725         self.causal = causal
726         self.attention_dropout = attention_dropout
727         self.record_attention = False
728
729         self.w_q = randw(nb_heads, dim_qk, dim_model)
730         self.w_k = randw(nb_heads, dim_qk, dim_model)
731         self.w_v = randw(nb_heads, dim_v, dim_model)
732         self.w_o = randw(dim_v * nb_heads, dim_model)
733
734     def forward(self, bs):
735         x_q = bs.x
736
737         assert (
738             self.causal or bs.complete()
739         ), "Partial evaluation is only possible for causal models"
740
741         if bs.init_cache:
742             self.cache_k = x_q.new_zeros(
743                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
744             )
745             self.cache_v = x_q.new_zeros(
746                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
747             )
748             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
749
750         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
751
752         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
753             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
754         )
755         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
756             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
757         )
758
759         a = torch.einsum(
760             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
761         ) / math.sqrt(self.w_q.size(1))
762
763         if self.causal:
764             if bs.init_cache:
765                 self.cache_attzero = (
766                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
767                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
768                 )
769             a = a.masked_fill(
770                 self.cache_attzero[
771                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
772                 ],
773                 float("-inf"),
774             )
775
776         a = a.softmax(dim=3)
777
778         if self.record_attention:
779             self.a = a
780
781         a = F.dropout(a, self.attention_dropout, self.training)
782
783         y = torch.einsum(
784             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
785         ).flatten(2)
786
787         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
788
789         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
790
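# A quick shape check (illustrative only, and assuming the optional args
# default used in the constructor above):
#
#   att = QKVAttention(dim_model=16, dim_qk=4, dim_v=4, nb_heads=2, causal=True)
#   y = att(BracketedSequence(torch.randn(3, 10, 16)))
#   assert y.x.size() == (3, 10, 16)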
791
792 ##############################
793
794
795 class MyGPT(nn.Module):
796     def __init__(
797         self,
798         vocabulary_size,
799         dim_model,
800         dim_keys,
801         dim_hidden,
802         nb_heads,
803         nb_blocks,
804         nb_lines=None,
805         caterpillar_height=None,
806         causal=False,
807         dropout=0.0,
808         len_max=1e5,
809         attention_layer="kvrec",
810         logger=print,
811         args=None,
812     ):
813         super().__init__()
814
815         assert attention_layer in {
816             "mha",
817             "dumbrec",
818             "kvrec",
819             "caterpillar",
820         }, f"Unknown attention operator {attention_layer}."
821
822         if attention_layer == "caterpillar":
823             assert nb_lines % caterpillar_height == 0
824             self.caterpillar_length = nb_lines // caterpillar_height
825             self.caterpillar_height = caterpillar_height
826         else:
827             self.caterpillar_length = -1
828             self.caterpillar_height = -1
829
830         assert dim_model % nb_heads == 0
831
832         self.embedding = nn.Sequential(
833             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
834             AddPositionalEncoding(len_max),
835         )
836
837         trunk_blocks = []
838
839         def attlayer():
840             if attention_layer == "mha":
841                 return QKVAttention(
842                     dim_model=dim_model,
843                     dim_qk=dim_keys,
844                     dim_v=dim_model // nb_heads,
845                     nb_heads=nb_heads,
846                     causal=causal,
847                     attention_dropout=dropout,
848                     logger=logger,
849                     args=args,
850                 )
851             elif attention_layer == "dumbrec":
852                 return DumbRec(
853                     dim_model=dim_model,
854                     dim_qk=dim_keys,
855                     dim_v=dim_model // nb_heads,
856                     nb_heads=nb_heads,
857                     nb_lines=nb_lines,
858                     attention_dropout=dropout,
859                     logger=logger,
860                     args=args,
861                 )
862             elif attention_layer == "kvrec":
863                 return KVRec(
864                     dim_model=dim_model,
865                     dim_qk=dim_keys,
866                     dim_v=dim_model // nb_heads,
867                     nb_heads=nb_heads,
868                     nb_lines=nb_lines,
869                     attention_dropout=dropout,
870                     logger=logger,
871                     args=args,
872                 )
873             elif attention_layer == "caterpillar":
874                 return Caterpillar(
875                     dim_model=dim_model,
876                     dim_qk=dim_keys,
877                     dim_v=dim_model // nb_heads,
878                     nb_heads=nb_heads,
879                     caterpillar_length=self.caterpillar_length,
880                     caterpillar_height=self.caterpillar_height,
881                     attention_dropout=dropout,
882                     logger=logger,
883                     args=args,
884                 )
885             else:
886                 raise ValueError(f"Unknown attention type {attention_layer}.")
887
888         for b in range(nb_blocks):
889             trunk_blocks += [
890                 WithResidual(
891                     CacheWrapper(nn.LayerNorm((dim_model,))),
892                     attlayer(),
893                 ),
894                 WithResidual(
895                     CacheWrapper(
896                         nn.LayerNorm((dim_model,)),
897                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
898                         nn.ReLU(),
899                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
900                         nn.Dropout(dropout),
901                     ),
902                 ),
903             ]
904
905         self.trunk = nn.Sequential(*trunk_blocks)
906
907         self.readout = CacheWrapper(
908             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
909         )
910
911         with torch.no_grad():
912             for m in self.modules():
913                 if isinstance(m, nn.Embedding):
914                     m.weight.normal_(mean=0, std=2e-2)
915                 elif isinstance(m, nn.LayerNorm):
916                     m.bias.zero_()
917                     m.weight.fill_(1.0)
918
919         self.reset_inner_loss()
920
921     def forward(self, bs):
922         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
923
924         # To keep the code simpler in the Caterpillar layer, we pad
925         # here. It is unclear how much this costs computationally by
926         # increasing the sequence length for the other layers
927
928         if self.caterpillar_length > 0:
929             original_nb = bs.nb
930             if bs.nb % self.caterpillar_length > 0:
931                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
932
933             bs = BracketedSequence(
934                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
935                 bs.first + self.caterpillar_length,
936                 bs.nb,
937                 bs.init_cache,
938             )
939
940         bs = self.embedding(bs)
941         bs = self.trunk(bs)
942         bs = self.readout(bs)
943
944         if self.caterpillar_length > 0:
945             bs = BracketedSequence(
946                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
947                 bs.first - self.caterpillar_length,
948                 original_nb,
949                 bs.init_cache,
950             )
951
952         return bs
953
954     # ar_mask is a tensor with 0s and 1s, of the same shape as input, with
955     # 1s where tokens should be generated. The others are kept
956     # unchanged.
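    # For instance, with input [[12, 34, 0, 0]] and ar_mask [[0, 0, 1, 1]],
    # the first two tokens act as a prompt and the last two are generated in
    # place.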
957
958     def masked_inplace_autoregression(
959         self,
960         input_src,
961         ar_mask_src,
962         forbidden_tokens=None,
963         deterministic_synthesis=False,
964     ):
965         input = input_src.to(self.readout.f.weight.device)
966         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
967         to_generate = (ar_mask.sum(0) > 0).nonzero()
968         if to_generate.min() > 0:
969             self(
970                 BracketedSequence(input, 0, to_generate.min(), True)
971             )  # Needed to initialize the model's cache
972         for s in range(to_generate.min(), to_generate.max() + 1):
973             output = self(BracketedSequence(input, s, 1, s == 0)).x
974             logits = output[:, s]
975             if forbidden_tokens is not None:
976                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
977             if deterministic_synthesis:
978                 t_next = logits.argmax(1)
979             else:
980                 dist = torch.distributions.categorical.Categorical(logits=logits)
981                 t_next = dist.sample()
982             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
983
984         input_src.copy_(input)
985
986     def reset_inner_loss(self):
987         for m in self.modules():
988             if m is not self and hasattr(m, "reset_inner_loss"):
989                 m.reset_inner_loss()
990
991     def get_inner_loss(self):
992         l = torch.tensor([0.0], device=self.readout.f.weight.device)
993         for m in self.modules():
994             if m is not self and hasattr(m, "get_inner_loss"):
995                 l += m.get_inner_loss()
996         return l
997
998     def record_attention(self, v=True):
999         for m in self.modules():
1000             if isinstance(m, QKVAttention):
1001                 m.record_attention = v
1002
1003     def retrieve_attention(self):
1004         a = []
1005         for m in self.modules():
1006             if isinstance(m, QKVAttention):
1007                 a.append(m.a)
1008         return a
1009
1010
1011 ######################################################################
1012
1013 if __name__ == "__main__":
1014     print("Basic check.")
1015
1016     m = Caterpillar(
1017         dim_model=4,
1018         dim_qk=3,
1019         dim_v=7,
1020         nb_heads=1,
1021         caterpillar_length=7,
1022         caterpillar_height=3,
1023         attention_dropout=0.0,
1024     )
1025
1026     m.reset_inner_loss()
1027     x = torch.randn(1, 21 + 2 * 7, 4)
1028     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1029     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1030     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1031     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1032     print((y1 - y2).abs().max())
1033     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1034     exit(0)
1035
1036     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1037
1038     vocabulary_size = 128
1039     x = torch.randint(vocabulary_size, (6, 1024))
1040
1041     model = MyGPT(
1042         vocabulary_size=vocabulary_size,
1043         dim_model=512,
1044         dim_keys=64,
1045         dim_hidden=2048,
1046         nb_heads=8,
1047         nb_lines=128,
1048         nb_blocks=12,
1049         dropout=0.1,
1050         causal=True,
1051     )
1052
1053     x = x.to(device)
1054     model.to(device)
1055
1056     import time, sys
1057
1058     # import torchvision.models as models
1059     # from torch.profiler import profile, record_function, ProfilerActivity
1060
1061     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1062     # with record_function("model_inference"):
1063
1064     model.eval()
1065     for i in range(3):
1066         start_time = time.perf_counter()
1067         for k in range(10):
1068             model(BracketedSequence(x))
1069         duration = time.perf_counter() - start_time
1070         print(duration)
1071         sys.stdout.flush()
1072
1073     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1074     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1075
1076     # print("##############################################################")
1077     # y2 = torch.randn_like(y1)
1078     # for s in range(x.size(1)):
1079     # z = model(BracketedSequence(x, s, 1))
1080     # y2[:, s : s + 1] = z.slice()
1081
1082     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1083
1084 ######################################################################