Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers to replace the MHA
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a bracket of
# time steps to compute: nb steps starting at index first.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor paired with the bracket [first, first + nb) to compute.

    Either first, nb and init_cache are all given, or none of them is. In
    the latter case the bracket spans the whole time axis (first=0,
    nb=x.size(1)) and init_cache defaults to True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the time steps of x covered by the bracket."""
        stop = self.first + self.nb
        return self.x[:, self.first : stop]

    def complete(self):
        """Return True when the bracket covers the entire time axis."""
        return self.nb == self.x.size(1) and self.first == 0
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket and keep its outputs in a cache.

    The wrapped module (or an nn.Sequential of several) only sees
    bs.slice(). Its outputs are written into a buffer spanning the full time
    axis, which is reallocated whenever bs.init_cache is True.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        lo, hi = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # The bracket must fit in the buffer allocated at init time
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert hi <= self.cache_y.size(1)
            self.cache_y[:, lo:hi] = self.f(bs.slice())
        else:
            out = self.f(bs.slice())
            full_shape = (out.size(0), bs.x.size(1)) + out.size()[2:]
            self.cache_y = out.new(*full_shape)
            self.cache_y[:, lo:hi] = out

        return BracketedSequence(self.cache_y, lo, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Wrap a bracket-processing module with a residual (skip) connection."""

    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f) if len(f) != 1 else f[0]

    def forward(self, bs):
        # The sum is taken over the full tensors, not only the bracket
        branch = self.f(bs)
        summed = bs.x + branch.x
        return BracketedSequence(summed, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to the current bracket.

    Following Vaswani et al. (2017):
    PE_{t,2i} = sin(t / L^{2i/D}) and PE_{t,2i+1} = cos(t / L^{2i/D}),
    with L = len_max and D the size of the last axis. The cosine channels
    are computed as sin(. + pi/2). The table and the output buffer are
    rebuilt when bs.init_cache is True.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            n_steps, dim = x.size(1), x.size(2)
            pos = torch.arange(n_steps, dtype=x.dtype, device=x.device).unsqueeze(1)
            chan = torch.arange(dim, dtype=x.dtype, device=x.device).unsqueeze(0)
            parity = chan % 2
            self.pe = torch.sin(
                pos / (self.len_max ** ((chan - parity) / dim)) + math.pi / 2 * parity
            )
            self.cache_y = x.new(x.size())

        lo, hi = bs.first, bs.first + bs.nb
        self.cache_y[:, lo:hi] = bs.slice() + self.pe[lo:hi]

        return BracketedSequence(self.cache_y, lo, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
130
131
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan with the sequence axis of X at position `dim`.

    All axes before `dim` are merged into a single batch axis. A is reshaped
    to (batch, T, *middle), X to (batch, T, *middle, -1) and Y_init to
    (batch, *middle, -1), where `middle` are the axes of X strictly between
    `dim` and the last one. Without Y_init, a zero initial state is used.
    The result has the shape of X.
    """
    shape = X.size()
    batch = shape[:dim].numel()
    steps = shape[dim]
    middle = shape[dim + 1 : -1]

    A = A.reshape(batch, steps, *middle)
    X = X.reshape(batch, steps, *middle, -1)

    if Y_init is not None:
        Y_init = Y_init.reshape(batch, *middle, -1)
    else:
        Y_init = X.new_zeros(batch, *middle, X.size(-1))

    return pscan.pscan(A, X, Y_init).reshape(shape)
147
148
def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
    """Experimental routine rewriting the .grad fields of A[t] and X[t].

    Under torch.no_grad(), walks t from X.size(dim) - 1 down to 1 (t = 0 is
    never visited). For each t it maintains running sums s_A / s_X and
    replaces A[t].grad (resp. X[t].grad) with
    (grad_Y[t] - s) / <previous .grad>.

    NOTE(review): this looks unfinished and presumably unused — no caller is
    visible here, TODO confirm:
      - A[t], X[t] and grad_Y[t] index the FIRST axis, while the loop bound
        uses X.size(dim);
      - if A and X are tensors, A[t] is an indexing view whose .grad is
        presumably None, so the divisions would raise; and assigning .grad
        on such a view would not update A.grad;
      - Y_init and eps are not used.
    """
    with torch.no_grad():
        s_A, s_X = 0, 0
        for t in range(X.size(dim) - 1, 0, -1):
            delta = (grad_Y[t] - s_A) / A[t].grad
            s_A += A[t].grad * delta
            A[t].grad = delta
            delta = (grad_Y[t] - s_X) / X[t].grad
            s_X += X[t].grad * delta
            X[t].grad = delta
159
160
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan over the second-to-last axis of X.

    X is ...xTxD and A is ...xT. All leading axes are merged into one batch
    axis. Y_init, if given, is reshaped to (batch, D); otherwise a zero
    initial state is used. The result has the shape of X.
    """
    size = X.size()
    steps, width = size[-2], size[-1]

    A = A.reshape(-1, steps)
    X = X.reshape(-1, steps, width)
    Y_init = (
        X.new_zeros(X.size(0), width) if Y_init is None else Y_init.reshape(-1, width)
    )

    return pscan.pscan(A, X, Y_init).reshape(size)
174
175
def nsum_shape(X, Y_init):
    """Norm-clamped running sum over the second-to-last axis of X.

    X is ...xTxD, and all leading axes are merged into one batch axis. The
    state starts from Y_init (reshaped to (batch, D)) or from 0. At each
    step the state receives X[:, t] and is then divided by
    max(1, ||state||_2) over the last axis, so its norm never exceeds 1.

    Returns a tensor with the shape of X holding the state after each step.
    An empty time axis returns an empty tensor of that shape.
    """
    s = X.size()

    if s[-2] == 0:
        # reshape(-1, 0, D) is ambiguous and torch.stack([]) fails, and
        # there is nothing to accumulate anyway
        return X.new_zeros(s)

    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    states = []

    for x_t in X.unbind(dim=1):
        Y = Y + x_t
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        states.append(Y)

    return torch.stack(states, dim=1).reshape(s)
189
190
191 ##############################
192
193
class DumbRec(nn.Module):
    """Recurrent replacement for attention with nb_lines memory lines.

    The state holds, for every sample and time step, nb_lines value vectors.
    Writing: each head matches a write query against the learned keys
    k_star, which are cyclically rotated with the time index, and takes a
    softmax over the lines. The head's value is added to every line with
    that weight, and the line's previous content is carried over by
    pscan_shape with factor 1 - (total write weight over heads). This
    presumably follows rec_t = A_t * rec_{t-1} + V_t; TODO confirm against
    pscan.
    Reading: each head attends over the lines with a read query matched
    against the unrotated k_star.

    The recurrent state and outputs are cached so that partial brackets can
    be processed. The caches are reallocated when bs.init_cache is True.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # Make the accumulators exist even if nobody calls reset_inner_loss
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Clear the write-attention statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared mean write attention.

        The statistics are accumulated during training forwards since the
        last reset_inner_loss(). Returns a zero scalar if nothing was
        accumulated, instead of dividing 0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag; self.train is a (always truthy) method
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue from the cached state when this is not the first bracket
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
320
321
322 ##############################
323
324
class KVRec(nn.Module):
    """Recurrent replacement for attention with nb_lines key/value lines.

    Like DumbRec, but every line stores both a value and a key.
    Writing: each head matches a write query against the learned keys
    k_star, which are cyclically rotated with the time index, and takes a
    softmax over the lines. The head's key and value are added to every
    line with that weight, and the line's previous content is carried over
    by pscan_shape with factor 1 - (total write weight over heads). This
    presumably follows rec_t = A_t * rec_{t-1} + X_t; TODO confirm against
    pscan.
    Reading: each head attends over the lines using the recurrent keys.

    The recurrent state and outputs are cached so that partial brackets can
    be processed. The caches are reallocated when bs.init_cache is True.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

        # Make the accumulators exist even if nobody calls reset_inner_loss
        self.reset_inner_loss()

    def reset_inner_loss(self):
        """Clear the write-attention statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared mean write attention.

        The statistics are accumulated during training forwards since the
        last reset_inner_loss(). Penalizing them encourages all memory lines
        to be used similarly. Returns a zero scalar if nothing was
        accumulated, instead of dividing 0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: at time t, line l uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training is the mode flag; self.train is a (always truthy) method
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the cached state when this is not the first bracket
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
456
457
458 ##############################
459
460
def moving_window(x, dim, win_dim, win_size):
    """Return a sliding-window view of x along axis `dim`, without copying.

    The result has one more axis than x. The new axis is inserted at
    position `win_dim` and ranges over {0, ..., win_size - 1}. Axis `dim` is
    shortened by win_size - 1. Position i along the shortened axis combined
    with offset k along the window axis reads x at position i + k along
    `dim`, because the window axis reuses the stride of `dim`.

    `dim` may be negative (counted from the end of x's axes). `win_dim` is
    an insertion position in the size tuple, as before.
    """
    # The slicing below assumes a non-negative axis: with dim=-1,
    # size[dim + 1 :] would be size[0:], i.e. the whole tuple
    dim = dim % x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
473
474
475 ##############################
476
477
class Caterpillar(nn.Module):
    """Recurrent replacement for multi-head attention ("caterpillar").

    The recurrent state is a stack of R = caterpillar_height rows of
    key/value pairs. The state at time t is obtained from the state at time
    t - L (L = caterpillar_length): it is decayed, and gated new keys/values
    computed from the input at time t are added. At time t the readout
    attends over the R x L key/value pairs stored at times t-L+1 ... t.

    The forward pass requires bs.first >= L and bs.nb to be a multiple of L.
    The states and outputs are cached for auto-regression, and the caches
    are reallocated when bs.init_cache is True.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        args=None,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, factor=1):
            return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Gate-dropout options; disabled when args does not provide them
        self.gate_dropout_proba = getattr(args, "gate_dropout_proba", 0.0)
        self.gate_dropout_sync = getattr(args, "gate_dropout_sync", False)
        self.gate_dropout_replace = getattr(args, "gate_dropout_replace", False)

        ######################################################################

        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    @staticmethod
    def _gating(G):
        """Turn per-head gates into a per-row decay and per-head write weights.

        G is NxHxRxT with values in [0, 1]. Returns (A, W), where:
          A[n, r, t]    = prod_h (1 - G[n, h, r, t])                 (NxRxT)
          W[n, h, r, t] = G[n, h, r, t] * prod_{h' != h} (1 - G[n, h', r, t])

        The product over the other heads uses exclusive prefix and suffix
        cumulative products rather than dividing by 1 - G. This keeps the
        result finite when a sigmoid gate saturates at exactly 1.
        """
        U = 1 - G
        ones = torch.ones_like(U[:, :1])
        # before[h] = prod_{h' < h} U[h'],  after[h] = prod_{h' > h} U[h']
        before = torch.cat([ones, U[:, :-1]], dim=1).cumprod(dim=1)
        after = torch.cat([U[:, 1:], ones], dim=1).flip(1).cumprod(dim=1).flip(1)
        A = U.prod(dim=1)
        W = G * before * after
        return A, W

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        ######################################################################

        def recurrence(G, V, K):
            # The decay and write weights are derived from the gates passed
            # in, so that masked gates (gate dropout) are handled consistently
            A, W = self._gating(G)  # A is NxRxT, W is NxHxRxT

            gated_V = torch.einsum("nhrt,nhtd->nrtd", W, V)
            gated_K = torch.einsum("nhrt,nhtd->nrtd", W, K)

            # We start from cached values, which matters in inference

            init_rec_V = self.rec_V[:, :, t0 - L : t0]
            init_rec_K = self.rec_K[:, :, t0 - L : t0]

            # Here there is a trick: Since the stack at position t is
            # computed by updating that at position t-L, the parallel
            # scan operates with a period of L. To do so we split the
            # time axis (axis 2 of A and of the gated tensors) in two
            # axes, the second of size L, and run the parallel scan
            # using the first as the sequence index.

            A = A.unflatten(2, (-1, L))
            gated_V = gated_V.unflatten(2, (-1, L))
            gated_K = gated_K.unflatten(2, (-1, L))

            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)

            return next_V, next_K

        #################################################################

        next_V, next_K = recurrence(G, V, K)

        if self.training and self.gate_dropout_proba > 0.0:
            # G is NxHxRxT where r is the caterpillar's row.

            warnings.warn("gate dropout", RuntimeWarning)

            if self.gate_dropout_sync:
                shape_kill = (N, 1, 1)
            else:
                shape_kill = (N, H, R)

            # Pick a point in each of the NxHxR timeline and set this
            # entry and the following to 1
            kill = (
                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
                == 0
            ).cumsum(dim=3)

            # Keep these mask for only some of the NxHxR
            kill = kill * (
                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
            )

            # The coefficient to keep are the complementary
            mask = 1 - kill

            masked_next_V, masked_next_K = recurrence(G * mask, V, K)

            if self.gate_dropout_replace:
                next_V = next_V.detach()
                next_K = next_K.detach()

            warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)

            # Forward values stay those of the unmasked recurrence; only the
            # gradient of the masked recurrence is added
            next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
                1 - self.gate_dropout_proba
            )
            next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
                1 - self.gate_dropout_proba
            )

        self.rec_V[:, :, t0:t1] = next_V
        self.rec_K[:, :, t0:t1] = next_K

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxRxTxLxD where N is the sample index, R the
        # row in the caterpillar, T the time, and L the column in the
        # caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nrtld->nhtrl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
712
713
714 ##############################
715
716
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a key/value cache.

    Keys and values of the current bracket are written into caches spanning
    the full time axis. The queries of the bracket then attend over all
    cached positions up to the end of the bracket, with a causal mask if
    `causal` is set. Partial brackets are only allowed for causal models.
    When record_attention is True, the attention weights (before dropout)
    are stored in self.a.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        args=None,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        lo, hi = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            n, t_total = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(n, self.w_k.size(0), t_total, self.w_k.size(1))
            self.cache_v = x.new_zeros(n, self.w_v.size(0), t_total, self.w_v.size(1))
            self.cache_y = x.new_zeros(n, t_total, self.w_o.size(1))

        x_new = x[:, lo:hi]
        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Everything up to the end of the bracket is visible to the queries
        keys = self.cache_k[:, :, :hi]
        values = self.cache_v[:, :, :hi]

        scores = torch.einsum("nhtd,nhsd->nhts", q, keys) / math.sqrt(
            self.w_q.size(1)
        )

        if self.causal:
            if bs.init_cache:
                # True where the key position s is after the query position t
                positions = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            scores = scores.masked_fill(
                self.cache_attzero[:, :, lo:hi, :hi], float("-inf")
            )

        attn = scores.softmax(dim=3)

        if self.record_attention:
            self.a = attn

        attn = F.dropout(attn, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", attn, values).flatten(2)

        self.cache_y[:, lo:hi] = y @ self.w_o

        return BracketedSequence(self.cache_y, lo, bs.nb, bs.init_cache)
799
800
801 ##############################
802
803
804 class MyGPT(nn.Module):
805     def __init__(
806         self,
807         vocabulary_size,
808         dim_model,
809         dim_keys,
810         dim_hidden,
811         nb_heads,
812         nb_blocks,
813         nb_lines=None,
814         caterpillar_height=None,
815         causal=False,
816         dropout=0.0,
817         len_max=1e5,
818         attention_layer="caterpillar",
819         logger=print,
820         args=None,
821     ):
822         super().__init__()
823
824         assert attention_layer in {
825             "mha",
826             "dumbrec",
827             "kvrec",
828             "caterpillar",
829         }, f"Unknown attention operator {attention_layer}."
830
831         if attention_layer == "caterpillar":
832             assert nb_lines % caterpillar_height == 0
833             self.caterpillar_length = nb_lines // caterpillar_height
834             self.caterpillar_height = caterpillar_height
835         else:
836             self.caterpillar_length = -1
837             self.caterpillar_height = -1
838
839         assert dim_model % nb_heads == 0
840
841         self.embedding = nn.Sequential(
842             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
843             AddPositionalEncoding(len_max),
844         )
845
846         trunk_blocks = []
847
848         def attlayer():
849             if attention_layer == "mha":
850                 return QKVAttention(
851                     dim_model=dim_model,
852                     dim_qk=dim_keys,
853                     dim_v=dim_model // nb_heads,
854                     nb_heads=nb_heads,
855                     causal=causal,
856                     attention_dropout=dropout,
857                     logger=logger,
858                     args=args,
859                 )
860             elif attention_layer == "dumbrec":
861                 return DumbRec(
862                     dim_model=dim_model,
863                     dim_qk=dim_keys,
864                     dim_v=dim_model // nb_heads,
865                     nb_heads=nb_heads,
866                     nb_lines=nb_lines,
867                     attention_dropout=dropout,
868                     logger=logger,
869                     args=args,
870                 )
871             elif attention_layer == "kvrec":
872                 return KVRec(
873                     dim_model=dim_model,
874                     dim_qk=dim_keys,
875                     dim_v=dim_model // nb_heads,
876                     nb_heads=nb_heads,
877                     nb_lines=nb_lines,
878                     attention_dropout=dropout,
879                     logger=logger,
880                     args=args,
881                 )
882             elif attention_layer == "caterpillar":
883                 return Caterpillar(
884                     dim_model=dim_model,
885                     dim_qk=dim_keys,
886                     dim_v=dim_model // nb_heads,
887                     nb_heads=nb_heads,
888                     caterpillar_length=self.caterpillar_length,
889                     caterpillar_height=self.caterpillar_height,
890                     attention_dropout=dropout,
891                     logger=logger,
892                     args=args,
893                 )
894             else:
895                 raise ValueError(f"Unknown attention type {attention_layer}.")
896
897         for b in range(nb_blocks):
898             trunk_blocks += [
899                 WithResidual(
900                     CacheWrapper(nn.LayerNorm((dim_model,))),
901                     attlayer(),
902                 ),
903                 WithResidual(
904                     CacheWrapper(
905                         nn.LayerNorm((dim_model,)),
906                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
907                         nn.ReLU(),
908                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
909                         nn.Dropout(dropout),
910                     ),
911                 ),
912             ]
913
914         self.trunk = nn.Sequential(*trunk_blocks)
915
916         self.readout = CacheWrapper(
917             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
918         )
919
920         with torch.no_grad():
921             for m in self.modules():
922                 if isinstance(m, nn.Embedding):
923                     m.weight.normal_(mean=0, std=2e-2)
924                 elif isinstance(m, nn.LayerNorm):
925                     m.bias.zero_()
926                     m.weight.fill_(1.0)
927
928         self.reset_inner_loss()
929
930     def forward(self, bs):
931         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
932
933         # To make the code simpler in the Caterpillar layer, we pad
934         # here. It's unclear if/how much it hurts computationaly by
935         # increasing the sequence length for the other layers
936
937         if self.caterpillar_length > 0:
938             original_nb = bs.nb
939             if bs.nb % self.caterpillar_length > 0:
940                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
941
942             bs = BracketedSequence(
943                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
944                 bs.first + self.caterpillar_length,
945                 bs.nb,
946                 bs.init_cache,
947             )
948
949         bs = self.embedding(bs)
950         bs = self.trunk(bs)
951         bs = self.readout(bs)
952
953         if self.caterpillar_length > 0:
954             bs = BracketedSequence(
955                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
956                 bs.first - self.caterpillar_length,
957                 original_nb,
958                 bs.init_cache,
959             )
960
961         return bs
962
    # ar_mask is a tensor of 0s and 1s with the same shape as input,
    # with 1s where tokens should be generated. The other tokens are
    # kept unchanged.
966
967     def masked_inplace_autoregression(
968         self,
969         input_src,
970         ar_mask_src,
971         forbidden_tokens=None,
972         deterministic_synthesis=False,
973     ):
974         input = input_src.to(self.readout.f.weight.device)
975         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
976         to_generate = (ar_mask.sum(0) > 0).nonzero()
977         if to_generate.min() > 0:
978             self(
979                 BracketedSequence(input, 0, to_generate.min(), True)
980             )  # Needed to initialize the model's cache
981         for s in range(to_generate.min(), to_generate.max() + 1):
982             output = self(BracketedSequence(input, s, 1, s == 0)).x
983             logits = output[:, s]
984             if forbidden_tokens is not None:
985                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
986             if deterministic_synthesis:
987                 t_next = logits.argmax(1)
988             else:
989                 dist = torch.distributions.categorical.Categorical(logits=logits)
990                 t_next = dist.sample()
991             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
992
993         input_src.copy_(input)
994
995     def reset_inner_loss(self):
996         for m in self.modules():
997             if m is not self and hasattr(m, "reset_inner_loss"):
998                 m.reset_inner_loss()
999
1000     def get_inner_loss(self):
1001         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1002         for m in self.modules():
1003             if m is not self and hasattr(m, "get_inner_loss"):
1004                 l += m.get_inner_loss()
1005         return l
1006
1007     def record_attention(self, v=True):
1008         for m in self.modules():
1009             if isinstance(m, QKVAttention):
1010                 m.record_attention = v
1011
1012     def retrieve_attention(self):
1013         a = []
1014         for m in self.modules():
1015             if isinstance(m, QKVAttention):
1016                 a.append(m.a)
1017         return a
1018
1019
1020 ######################################################################
1021
if __name__ == "__main__":
    # Scratch experiments. Only the first section runs: it plots, for a
    # QKVAttention and a Caterpillar layer, the squared gradient norm of a
    # few output components w.r.t. the input, saves it to plot.pdf, and
    # exits. The sections after the sys.exit(0) call are unreachable.

    import argparse, sys

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.collections as mc

    args = argparse.Namespace(
        gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dim_model, dim_keys, nb_heads = 512, 64, 1
    dropout = 0.1

    caterpillar = Caterpillar(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        caterpillar_length=16,
        caterpillar_height=32,
        attention_dropout=dropout,
        args=args,
    ).to(device)

    qkv = QKVAttention(
        dim_model=dim_model,
        dim_qk=dim_keys,
        dim_v=dim_model // nb_heads,
        nb_heads=nb_heads,
        causal=True,
        attention_dropout=dropout,
        args=args,
    ).to(device)

    linear = CacheWrapper(nn.Linear(512, 512)).to(device)

    x = torch.randn(1, 256, dim_model)

    x = x.to(device)
    x.requires_grad_()

    ######################################################################

    fig = plt.figure()
    fig.set_figheight(6)
    fig.set_figwidth(8)

    ax = fig.add_subplot(1, 1, 1)

    # ax.set_xlim(-1.5, 1.5)
    # ax.set_ylim(-1.5, 1.5)
    # ax.set(aspect=1)
    # ax.spines.right.set_visible(False)
    # ax.spines.top.set_visible(False)

    # dt = 0.01
    # t = np.arange(dt, 20.0, dt)
    # ax.semilogx(t, np.exp(-t / 5.0))
    # ax.grid()

    ######################################################################

    # For each model, scatter (t, |d y[0, t, d] / d x|^2) for 8 random
    # output components d at every time step t.
    # NOTE(review): the models are never switched to eval(), so
    # attention_dropout=0.1 is presumably active during this measurement.
    # Confirm that this is intended.

    for label, model in [
        # ("nn.Linear", linear),
        ("mygpt.QKVAttention", qkv),
        ("mygpt.Caterpillar", caterpillar),
    ]:
        y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x

        data = []
        for t in range(y.size(1)):
            for d in torch.randperm(y.size(2))[:8]:
                g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
                sg = g.pow(2).sum().item()
                # sg = 0
                # for p in model.parameters():
                # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
                # sg = sg + g.pow(2).sum().item()
                data.append([t, sg])

        data = torch.tensor(data)
        ax.scatter(
            data[:, 0], data[:, 1], s=1, label=label
        )  # , color='gray', label='Input')

    # ax.legend(frameon=False, loc="top right")

    # Put a legend to the right of the current axis
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    filename = "plot.pdf"
    print(f"saving {filename}")
    fig.savefig(filename, bbox_inches="tight")

    # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
    # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
    # plt.show()

    sys.exit(0)

    ######################################################################
    # Everything below is unreachable because of the sys.exit(0) above.
    # It is kept as scratch code for other experiments.

    # Consistency check of the Caterpillar cache: processing a sequence in
    # one bracket or in two consecutive brackets should give the same
    # output.
    # NOTE(review): unlike the Caterpillar construction above, no `args`
    # is passed here. Confirm the constructor accepts that.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    sys.exit(0)

    # Timing of full-sequence forward passes of a MyGPT model
    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1194
1195 ######################################################################