#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is, a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.

# This implementation is equipped with RNN layers to replace the MHA.

import math, warnings

import torch, einops

from torch import nn
from torch.nn import functional as F

import ffutils

# from blanket import blanket

# import memload

######################################################################

# A BracketedSequence is a BxTx... tensor together with the index of
# the first time step to compute and the number of time steps to
# compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)

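
# Illustrative sketch (not part of the original file): how a module
# that understands BracketedSequence is driven during prompt-conditioned
# generation, here with a CacheWrapper around a plain linear layer.
#
#   x = torch.randn(2, 10, 4)                  # B=2, T=10, D=4
#   m = CacheWrapper(nn.Linear(4, 4))
#   y = m(BracketedSequence(x, 0, 5, True))    # "prompt" bracket [0, 5)
#   for t in range(5, 10):                     # then one step at a time
#       y = m(BracketedSequence(x, t, 1, False))
#   # y.x now holds the same values as m(BracketedSequence(x)).x
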
######################################################################


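# CacheWrapper applies a module that processes each time step
# independently (e.g. nn.Linear, nn.Embedding, nn.LayerNorm) to the
# current bracket only, and caches the result so that the full BxTx...
# output tensor is available to downstream layers.
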
class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)

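# Quick sanity check of the formula above (illustrative, not part of
# the original file): for D=4 and len_max=L, entry (t, 2i) should be
# sin(t / L^(2i/D)) and entry (t, 2i+1) should be cos(t / L^(2i/D)),
# since sin(u + pi/2) = cos(u).
#
#   pe = AddPositionalEncoding(len_max=100)
#   y = pe(BracketedSequence(torch.zeros(1, 8, 4)))
#   assert abs(y.x[0, 3, 2] - math.sin(3 / 100 ** (2 / 4))) < 1e-5
#   assert abs(y.x[0, 3, 3] - math.cos(3 / 100 ** (2 / 4))) < 1e-5
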
import pscan

# X is /.../xTxD   A is /.../xT   Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y

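# The parallel scan computes the linear recurrence
#
#   Y[..., t, :] = A[..., t, None] * Y[..., t - 1, :] + X[..., t, :]
#
# along dim, starting from Y_init. Illustrative check (not part of the
# original file), assuming pscan.pscan implements that recurrence:
#
#   A = torch.full((1, 3), 0.5)
#   X = torch.ones(1, 3, 2)
#   Y = pscan_dim(A, X, None)
#   # Y[0] is [[1.0, 1.0], [1.5, 1.5], [1.75, 1.75]]
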
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    with torch.no_grad():
        s_A, s_X = 0, 0
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            A[t].grad = delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta
            X[t].grad = delta


def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


# Returns a tensor with an additional index at rank win_dim, that
# moves along the same dimension as dim, over the domain
# {0...win_size-1}, while dim is restricted to a domain reduced by
# win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)

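# Illustrative check (not part of the original file): with dim=0 and
# win_dim=1, each window reads win_size consecutive values, as a view
# with no copy.
#
#   x = torch.arange(5)                              # [0, 1, 2, 3, 4]
#   w = moving_window(x, dim=0, win_dim=1, win_size=3)
#   # w is [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
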
##############################


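# Caterpillar maintains a recurrent state of caterpillar_height rows
# of keys/values, refreshed every caterpillar_length time steps with a
# gated parallel scan, and read out at every time step with attention
# over the RxL recurrent entries currently visible.
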
class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, factor=1):
            return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    # def reset_inner_loss(self):
    # self.acc_attention = 0
    # self.acc_nb = 0

    # def get_inner_loss(self):
    # warnings.warn("l2 regularization", RuntimeWarning)
    # return (self.acc_attention / self.acc_nb).pow(2).sum()
    # return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Name the dimensions to make the source a bit clearer

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################

        A = 1 - G.sum(dim=1)

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        # Here there is a trick: since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # Q = blanket(Q)

        # We build tensors NxHxTxRxL where N is the sample index, H
        # the head, T the time, R the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtrl,nrtld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)


##############################


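# Standard multi-head QKV attention, with optional causal masking and
# an optional horizon limiting how far back a query can attend. Keys
# and values are cached to support bracket-by-bracket auto-regression.
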
class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        horizon=None,
        attention_dropout=0.0,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.horizon = horizon
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )

                if self.horizon is not None:
                    self.cache_attzero = torch.logical_or(
                        self.cache_attzero,
                        torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                        >= torch.arange(x_q.size(1), device=q.device)[
                            None, None, None, :
                        ]
                        + self.horizon,
                    )

            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="caterpillar",
        logger=print,
        args=None,
    ):
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
            "attcat",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar" or attention_layer == "attcat":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    QKVAttention(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                        logger=logger,
                        args=args,
                    ),
                )
            elif attention_layer == "dumbrec":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    DumbRec(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        nb_lines=nb_lines,
                        attention_dropout=dropout,
                        logger=logger,
                        args=args,
                    ),
                )
            elif attention_layer == "kvrec":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    KVRec(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        nb_lines=nb_lines,
                        attention_dropout=dropout,
                        logger=logger,
                        args=args,
                    ),
                )
            elif attention_layer == "caterpillar":
                return WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    Caterpillar(
                        dim_model=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        caterpillar_length=self.caterpillar_length,
                        caterpillar_height=self.caterpillar_height,
                        attention_dropout=dropout,
                        logger=logger,
                        args=args,
                    ),
                )
            elif attention_layer == "attcat":
                return nn.Sequential(
                    WithResidual(
                        CacheWrapper(nn.LayerNorm((dim_model,))),
                        QKVAttention(
                            dim_model=dim_model,
                            dim_qk=dim_keys,
                            dim_v=dim_model // nb_heads,
                            nb_heads=nb_heads,
                            causal=causal,
                            horizon=self.caterpillar_length,
                            attention_dropout=dropout,
                            logger=logger,
                            args=args,
                        ),
                    ),
                    WithResidual(
                        CacheWrapper(nn.LayerNorm((dim_model,))),
                        Caterpillar(
                            dim_model=dim_model,
                            dim_qk=dim_keys,
                            dim_v=dim_model // nb_heads,
                            nb_heads=nb_heads,
                            caterpillar_length=self.caterpillar_length,
                            caterpillar_height=self.caterpillar_height,
                            attention_dropout=dropout,
                            logger=logger,
                            args=args,
                        ),
                    ),
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                attlayer(),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of the same shape as the
    # input, with 1s where tokens should be generated. The others are
    # kept unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


######################################################################

if __name__ == "__main__":
    import argparse

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.collections as mc

    args = argparse.Namespace(
        gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dim_model, dim_keys, nb_heads = 512, 64, 1
    dropout = 0.1

    caterpillar = Caterpillar(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        caterpillar_length=16,
        caterpillar_height=32,
        attention_dropout=dropout,
        args=args,
    ).to(device)

    qkv = QKVAttention(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        causal=True,
        attention_dropout=dropout,
        args=args,
    ).to(device)

    linear = CacheWrapper(nn.Linear(512, 512)).to(device)

    x = torch.randn(1, 256, dim_model)

    x = x.to(device)
    x.requires_grad_()

    ######################################################################

    fig = plt.figure()
    fig.set_figheight(6)
    fig.set_figwidth(8)

    ax = fig.add_subplot(1, 1, 1)

    # ax.set_xlim(-1.5, 1.5)
    # ax.set_ylim(-1.5, 1.5)
    # ax.set(aspect=1)
    # ax.spines.right.set_visible(False)
    # ax.spines.top.set_visible(False)

    # dt = 0.01
    # t = np.arange(dt, 20.0, dt)
    # ax.semilogx(t, np.exp(-t / 5.0))
    # ax.grid()
    ax.set_yscale("log")

    ######################################################################

    for label, model, thickness in [
        ("nn.Linear", linear, 0.2),
        ("mygpt.QKVAttention", qkv, 1),
        ("mygpt.Caterpillar", caterpillar, 2),
    ]:
        y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x

        for n, p in [("input", x)] + list(model.named_parameters()):
            print(f"Processing {model}.{n}")
            data = []
            for t in range(y.size(1)):
                sg = 0
                for d in torch.randperm(y.size(2))[:8]:
                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
                assert not sg.isinf().any()
                assert not sg.isnan().any()
                data.append([t, sg.sum().item()])

            data = torch.tensor(data)
            # cx, cy = data[:, 0], data[:, 1]
            cy = data[:, 1].sort().values
            cx = torch.linspace(0, 1, cy.size(0))
            ax.plot(
                cx, cy, label=label + "." + n, linewidth=thickness
            )  # , color='gray', label='Input')

    # ax.legend(frameon=False, loc="top right")

    # Put a legend to the right of the current axis
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    filename = "plot.pdf"
    print(f"saving {filename}")
    fig.savefig(filename, bbox_inches="tight")

    # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
    # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
    # plt.show()

    exit(0)

    ######################################################################

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    exit(0)

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################