[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers that can replace the MHA.
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 from blanket import blanket
25
26 # import memload
27
28 ######################################################################
29
30 # A BracketedSequence is a BxTx... tensor together with the index of the
31 # first time step to compute and the number of time steps to compute.
32
33 # Modules able to process it expect that they will have to process a
34 # first bracket starting at t=0, followed by a succession of brackets
35 # that move forward in time, do not overlap, and cover the axis T with
36 # no holes.
37 #
38 # Although it is more general, for a classical prompt-conditioned
39 # auto-regressive process it will be a first bracket starting at 0 and
40 # of arbitrary length for the "prompt", followed by brackets of length
41 # 1 for the successive tokens.
42 #
43 # Modules able to process brackets may implement a cache that is
44 # reset when init_cache is True
45
46
47 class BracketedSequence:
48     def __init__(self, x, first=None, nb=None, init_cache=None):
49         self.x = x
50         assert (first is None and nb is None and init_cache is None) or (
51             first is not None and nb is not None and init_cache is not None
52         )
53
54         self.first = 0 if first is None else first
55         self.nb = x.size(1) if nb is None else nb
56         self.init_cache = True if init_cache is None else init_cache
57
58     def slice(self):
59         return self.x[:, self.first : self.first + self.nb]
60
61     def complete(self):
62         return self.first == 0 and self.nb == self.x.size(1)
63
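# Illustrative sketch (commented out, not used elsewhere in this file):
# a prompt of length 4 processed as a first bracket, followed by
# length-1 brackets for the generated tokens.
#
# x = torch.randn(2, 7, 16)
# bs = BracketedSequence(x, first=0, nb=4, init_cache=True)
# assert bs.slice().shape == (2, 4, 16) and not bs.complete()
# for t in range(4, 7):
#     bs = BracketedSequence(x, first=t, nb=1, init_cache=False)
#     assert bs.slice().shape == (2, 1, 16)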
64
65 ######################################################################
66
67
68 class CacheWrapper(nn.Module):
69     def __init__(self, *f):
70         super().__init__()
71         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
72
73     def forward(self, bs):
74         if bs.init_cache:
75             y = self.f(bs.slice())
76             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
77             self.cache_y[:, bs.first : bs.first + bs.nb] = y
78         else:
79             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
80             assert bs.first + bs.nb <= self.cache_y.size(1)
81             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
82
83         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
84
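# Minimal usage sketch (commented out, illustrative only): processing a
# sequence in two brackets yields the same cached output as a single
# full bracket, since the wrapped module is applied per time step.
#
# f = CacheWrapper(nn.Linear(16, 16))
# x = torch.randn(2, 10, 16)
# y_full = f(BracketedSequence(x)).x.clone()
# f(BracketedSequence(x, first=0, nb=6, init_cache=True))
# y_two = f(BracketedSequence(x, first=6, nb=4, init_cache=False)).x
# assert (y_full - y_two).abs().max() < 1e-6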
85
86 ##############################
87
88
89 class WithResidual(nn.Module):
90     def __init__(self, *f):
91         super().__init__()
92         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
93
94     def forward(self, bs):
95         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
96
97
98 ##############################
99
100
101 class AddPositionalEncoding(nn.Module):
102     def __init__(self, len_max):
103         super().__init__()
104         self.len_max = len_max
105
106     # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
107
108     def forward(self, bs):
109         if bs.init_cache:
110             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
111                 :, None
112             ]
113             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
114                 None, :
115             ]
116             k = j % 2
117             self.pe = torch.sin(
118                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
119             )
120             self.cache_y = bs.x.new(bs.x.size())
121
122         self.cache_y[:, bs.first : bs.first + bs.nb] = (
123             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
124         )
125
126         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
127
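# The single sin() above covers both cases of the classical formula,
# since sin(u + pi/2) = cos(u): for even j it gives sin(t / len_max^(j/D)),
# for odd j it gives cos(t / len_max^((j-1)/D)). Commented-out check
# (illustrative only):
#
# pe_layer = AddPositionalEncoding(len_max=1e5)
# pe = pe_layer(BracketedSequence(torch.zeros(1, 50, 64))).x[0]  # input is zero, so this is the PE itself
# t, i = 7, 3
# assert abs(pe[t, 2 * i] - math.sin(t / 1e5 ** (2 * i / 64))) < 1e-5
# assert abs(pe[t, 2 * i + 1] - math.cos(t / 1e5 ** (2 * i / 64))) < 1e-5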
128
129 import pscan
130
131 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
132
133
134 def pscan_dim(A, X, Y_init, dim=-2):
135     s = X.size()
136     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
137
138     A = A.reshape(a, T, *s[dim + 1 : -1])
139     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
140
141     if Y_init is None:
142         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
143     else:
144         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
145
146     Y = pscan.pscan(A, X, Y_init).reshape(s)
147
148     return Y
149
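# pscan.pscan is assumed here to compute the linear recurrence
# Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t], with Y[:, -1] = Y_init, in
# parallel over t. Naive reference version (commented out, sketch for
# checking shapes and semantics only):
#
# def naive_pscan(A, X, Y_init):
#     # A is (N, T), X is (N, T, D), Y_init is (N, D)
#     Y, ys = Y_init, []
#     for t in range(X.size(1)):
#         Y = A[:, t, None] * Y + X[:, t]
#         ys.append(Y)
#     return torch.stack(ys, dim=1)  # (N, T, D)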
150
151 def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
152     with torch.no_grad():
153         s_A, s_X = 0, 0
154         for t in range(X.size(dim) - 1, 0, -1):
155             delta = (grad_Y[t] - s_A) / A[t].grad
156             s_A += A[t].grad * delta
157             A[t].grad = delta
158             delta = (grad_Y[t] - s_X) / X[t].grad
159             s_X += X[t].grad * delta
160             X[t].grad = delta
161
162
163 def pscan_shape(A, X, Y_init):
164     s = X.size()
165     A = A.reshape(-1, s[-2])
166     X = X.reshape(-1, s[-2], s[-1])
167
168     if Y_init is None:
169         Y_init = X.new_zeros(X.size(0), s[-1])
170     else:
171         Y_init = Y_init.reshape(-1, s[-1])
172
173     Y = pscan.pscan(A, X, Y_init).reshape(s)
174
175     return Y
176
177
178 def nsum_shape(X, Y_init):
179     s = X.size()
180     X = X.reshape(-1, s[-2], s[-1])  # ntd
181
182     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
183     result = []
184
185     for k in range(X.size(1)):
186         Y = Y + X[:, k]
187         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
188         result.append(Y)
189
190     return torch.cat(result, dim=1).reshape(s)
191
192
193 ##############################
194
195
196 class DumbRec(nn.Module):
197     def __init__(
198         self,
199         dim_model,
200         dim_qk,
201         dim_v,
202         nb_heads,
203         nb_lines,
204         attention_dropout=0.0,
205         len_max=1e5,
206         logger=print,
207         args=None,
208     ):
209         super().__init__()
210
211         def randw(*d):
212             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
213
214         self.nb_lines = nb_lines
215         self.attention_dropout = attention_dropout
216
217         self.k_star = randw(nb_lines, dim_qk)
218
219         self.w_qw = randw(nb_heads, dim_qk, dim_model)
220         self.w_qr = randw(nb_heads, dim_qk, dim_model)
221         # self.w_k = randw(nb_heads, dim_qk, dim_model)
222         self.w_v = randw(nb_heads, dim_v, dim_model)
223         self.w_o = randw(dim_v * nb_heads, dim_model)
224
225     def reset_inner_loss(self):
226         self.acc_attention = 0
227         self.acc_nb = 0
228
229     def get_inner_loss(self):
230         warnings.warn("l2 regularization", RuntimeWarning)
231         return (self.acc_attention / self.acc_nb).pow(2).sum()
232         # return torch.tensor([0], device=self.w_qw.device)
233
234     def forward(self, bs):
235         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
236
237         if bs.init_cache:
238             self.rec_v = x_q.new_zeros(
239                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
240             )
241             # self.rec_k = x_q.new_zeros(
242             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
243             # )
244             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
245
246         ######################################################################
247         # Prepare the keys
248
249         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
250
251         warnings.warn("rotating key barrel", RuntimeWarning)
252         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
253         t_barrel = torch.arange(t0, t1, device=k_star.device)
254         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
255         l_barrel = (
256             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
257         ) % k_star.size(0)
258         k_star = k_star[l_barrel, t_barrel]
259
260         ######################################################################
261         # Compute the recurrent state
262
263         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
264
265         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
266         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
267
268         aw = torch.einsum(
269             "nhtd,ltd->nhlt",
270             qw,
271             k_star,
272         ) / math.sqrt(self.w_qw.size(1))
273
274         aw = aw.softmax(dim=2)  # nhlt
275
276         if self.training:
277             self.acc_attention += aw.sum(dim=(0, 1, 3))
278             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
279
280         aw = F.dropout(aw, self.attention_dropout, self.training)
281
282         A = 1 - aw.sum(dim=1)  # nlt
283
284         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
285         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
286
287         if t0 == 0:
288             V0 = None
289             # K0 = None
290         else:
291             V0 = self.rec_v[:, :, t0 - 1]
292             # K0 = self.rec_k[:, :, t0 - 1]
293
294         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
295         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
296
297         ######################################################################
298         # compute the readout
299
300         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
301
302         ar = torch.einsum(
303             "nhtd,ld->nhlt",
304             qr,
305             # self.rec_k[:, :, t0:t1],
306             self.k_star,
307         ) / math.sqrt(self.w_qr.size(1))
308
309         ar = ar.softmax(dim=2)  # nhlt
310
311         ar = F.dropout(ar, self.attention_dropout, self.training)
312
313         y = torch.einsum(
314             "nhlt,nltd->nthd",
315             ar,
316             self.rec_v[:, :, t0:t1],
317         ).flatten(2)
318
319         self.cache_y[:, t0:t1] = y @ self.w_o
320
321         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
322
323
324 ##############################
325
326
327 class KVRec(nn.Module):
328     def __init__(
329         self,
330         dim_model,
331         dim_qk,
332         dim_v,
333         nb_heads,
334         nb_lines,
335         attention_dropout=0.0,
336         len_max=1e5,
337         logger=print,
338         args=None,
339     ):
340         super().__init__()
341
342         def randw(*d):
343             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
344
345         self.nb_lines = nb_lines
346         self.attention_dropout = attention_dropout
347
348         self.k_star = randw(nb_lines, dim_qk)
349
350         self.w_qw = randw(nb_heads, dim_qk, dim_model)
351         self.w_qr = randw(nb_heads, dim_qk, dim_model)
352         self.w_k = randw(nb_heads, dim_qk, dim_model)
353         self.w_v = randw(nb_heads, dim_v, dim_model)
354         self.w_o = randw(dim_v * nb_heads, dim_model)
355
356     def reset_inner_loss(self):
357         self.acc_attention = 0
358         self.acc_nb = 0
359
360     def get_inner_loss(self):
361         warnings.warn("l2 regularization", RuntimeWarning)
362         return (self.acc_attention / self.acc_nb).pow(2).sum()
363         # return torch.tensor([0], device=self.w_qw.device)
364         # warnings.warn("side regularization", RuntimeWarning)
365         # return (
366         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
367         # )
368         # return torch.tensor([0], device=self.w_qw.device)
369
370     def forward(self, bs):
371         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
372
373         if bs.init_cache:
374             self.rec_v = x_q.new_zeros(
375                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
376             )
377             self.rec_k = x_q.new_zeros(
378                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
379             )
380             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
381
382         ######################################################################
383         # Prepare the keys
384
385         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
386
387         warnings.warn("rotating key barrel", RuntimeWarning)
388         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
389         t_barrel = torch.arange(t0, t1, device=k_star.device)
390         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
391         l_barrel = (
392             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
393         ) % k_star.size(0)
394         k_star = k_star[l_barrel, t_barrel]
395
396         ######################################################################
397         # Compute the recurrent state
398
399         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
400
401         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
402         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
403
404         aw = torch.einsum(
405             "nhtd,ltd->nhlt",
406             qw,
407             k_star,
408         ) / math.sqrt(self.w_qw.size(1))
409
410         aw = aw.softmax(dim=2)  # nhlt
411
412         if self.training:
413             # We want all the memory lines to be used similarly
414             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
415             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
416
417         aw = F.dropout(aw, self.attention_dropout, self.training)
418
419         A = 1 - aw.sum(dim=1)  # nlt
420
421         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
422         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
423
424         if t0 == 0:
425             V0 = None
426             K0 = None
427         else:
428             V0 = self.rec_v[:, :, t0 - 1]
429             K0 = self.rec_k[:, :, t0 - 1]
430
431         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
432         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
433
434         ######################################################################
435         # compute the readout
436
437         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
438
439         ar = torch.einsum(
440             "nhtd,nltd->nhlt",
441             qr,
442             self.rec_k[:, :, t0:t1],
443         ) / math.sqrt(self.w_qr.size(1))
444
445         ar = ar.softmax(dim=2)  # nhlt
446
447         ar = F.dropout(ar, self.attention_dropout, self.training)
448
449         y = torch.einsum(
450             "nhlt,nltd->nthd",
451             ar,
452             self.rec_v[:, :, t0:t1],
453         ).flatten(2)
454
455         self.cache_y[:, t0:t1] = y @ self.w_o
456
457         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
458
459
460 ##############################
461
462
463 # Returns a tensor with an additional index at rank win_dim, that moves
464 # along the same dimension as dim, over the domain {0...win_size-1}, while
465 # dim is restricted to a domain reduced by win_size-1 values.
466
467
468 def moving_window(x, dim, win_dim, win_size):
469     size, stride = x.size(), x.stride()
470     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
471     size = size[:win_dim] + (win_size,) + size[win_dim:]
472     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
473
474     return x.as_strided(size=size, stride=stride)
475
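# Commented-out example (illustrative only): a window of size 3 moving
# along dim=1 of a 1x5 tensor, with the window index added at dim=2.
#
# u = torch.arange(5)[None, :]                       # 1x5
# w = moving_window(u, dim=1, win_dim=2, win_size=3)
# assert w.shape == (1, 3, 3)
# assert w[0].tolist() == [[0, 1, 2], [1, 2, 3], [2, 3, 4]]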
476
477 ##############################
478
479
480 class Caterpillar(nn.Module):
481     def __init__(
482         self,
483         dim_model,
484         dim_qk,
485         dim_v,
486         nb_heads,
487         caterpillar_length,
488         caterpillar_height,
489         attention_dropout=0.0,
490         len_max=1e5,
491         logger=print,
492         args=None,
493     ):
494         super().__init__()
495
496         warnings.warn("Caterpillar", RuntimeWarning)
497
498         def randw(*d, factor=1):
499             return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))
500
501         self.caterpillar_length = caterpillar_length
502         self.caterpillar_height = caterpillar_height
503         self.attention_dropout = attention_dropout
504
505         self.gate_dropout_proba = args.gate_dropout_proba
506         self.gate_dropout_sync = args.gate_dropout_sync
507         self.gate_dropout_replace = args.gate_dropout_replace
508
509         ######################################################################
510
511         self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
512         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
513
514         self.w_K = randw(nb_heads, dim_qk, dim_model)
515         self.w_V = randw(nb_heads, dim_v, dim_model)
516         self.w_Q = randw(nb_heads, dim_qk, dim_model)
517         self.w_O = randw(dim_v * nb_heads, dim_model)
518
519         self.init_K_rec = randw(
520             caterpillar_height,
521             caterpillar_length,
522             dim_qk,
523         )
524         self.init_V_rec = randw(
525             caterpillar_height,
526             caterpillar_length,
527             dim_v,
528         )
529
530     # def reset_inner_loss(self):
531     # self.acc_attention = 0
532     # self.acc_nb = 0
533
534     # def get_inner_loss(self):
535     # warnings.warn("l2 regularization", RuntimeWarning)
536     # return (self.acc_attention / self.acc_nb).pow(2).sum()
537     # return torch.tensor([0], device=self.w_Q.device)
538
539     def forward(self, bs):
540         # Short names for the dimensions, to make the source a bit clearer
541
542         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
543
544         N = bs.x.size(0)
545         T = bs.x.size(1)
546         H = self.w_V.size(0)
547         DV = self.w_V.size(1)
548         DK = self.w_K.size(1)
549         DM = self.w_O.size(1)
550         R = self.caterpillar_height
551         L = self.caterpillar_length
552
553         assert (
554             t0 >= L and (t1 - t0) % L == 0
555         ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
556
557         # We cache values to deal efficiently with auto-regression
558
559         if bs.init_cache:
560             self.rec_V = X.new_zeros(N, R, T, DV)
561             self.rec_K = X.new_zeros(N, R, T, DK)
562             # We start the recurrent sequences with optimizable
563             # initial values. No idea if it helps.
564             self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :]
565             self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :]
566
567             self.cache_Y = X.new_zeros(N, T, DM)
568
569         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
570         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
571
572         V, K = blanket(V), blanket(K)
573
574         ######################################################################
575         # Compute the recurrent state
576
577         # This is the Gating sequence that modulates the storing of
578         # the new key and value in the R pairs of the current
579         # stack. There are R independent gating values, which means
580         # that the current K/V may be stored in multiple pairs of the
581         # recurrent state, or not at all.
582
583         G = (
584             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
585         ).sigmoid()
586
587         # Renormalize the gating so that the total gating of a row does not
588         # exceed 1 when several heads hit the same row
589
590         G = G / G.sum(1, keepdim=True).clamp(min=1)
591
592         # G_star = (1 - G).log().sum(1, keepdim=True).exp()
593
594         ######################################################################
595
596         def recurrence(G, V, K):
597             # We prepare the arguments for the parallel scan
598
599             A = 1 - G.sum(dim=1)
600
601             gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
602             gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
603
604             # We start from cached values, which matters in inference
605
606             init_rec_V = self.rec_V[:, :, t0 - L : t0]
607             init_rec_K = self.rec_K[:, :, t0 - L : t0]
608
609             # Here there is a trick: Since the stack at position t is
610             # computed by updating that at position t-L, the parallel
611             # scan operates with a period of L. To do so we split the
612             # sequence indexing in two axes, the second of size L, and
613             # run the parallel scan using the first as the sequence index.
614
615             A = A.unflatten(2, (-1, L))
616             gated_V = gated_V.unflatten(2, (-1, L))
617             gated_K = gated_K.unflatten(2, (-1, L))
618
619             next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
620             next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
621
622             return next_V, next_K
623
624         #################################################################
625
626         next_V, next_K = recurrence(G, V, K)
627
628         if self.training and self.gate_dropout_proba > 0.0:
629             # G is NxHxRxT where r is the caterpillar's row.
630
631             warnings.warn("gate dropout", RuntimeWarning)
632
633             if self.gate_dropout_sync:
634                 shape_kill = (N, 1, 1)
635             else:
636                 shape_kill = (N, H, R)
637
638             # Pick a point in each of the NxHxR timelines and set this
639             # entry and the following ones to 1
640             kill = (
641                 torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
642                 == 0
643             ).cumsum(dim=3)
644
645             # Keep these masks for only some of the NxHxR timelines
646             kill = kill * (
647                 torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
648             )
649
650             # The coefficients to keep are the complement
651             mask = 1 - kill
652
653             masked_next_V, masked_next_K = recurrence(G * mask, V, K)
654
655             if self.gate_dropout_replace:
656                 next_V = next_V.detach()
657                 next_K = next_K.detach()
658
659             warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)
660
661             next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
662                 1 - self.gate_dropout_proba
663             )
664             next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
665                 1 - self.gate_dropout_proba
666             )
667
668         self.rec_V[:, :, t0:t1] = next_V
669         self.rec_K[:, :, t0:t1] = next_K
670
671         ######################################################################
672         # compute the readout
673
674         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
675
676         Q = blanket(Q)
677
678         # We build tensors NxHxTxRxL where N is the sample index, H
679         # the head, T the time, R the row in the caterpillar, and L
680         # the column in the caterpillar
681
682         windowed_V = moving_window(
683             self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
684         )
685
686         windowed_K = moving_window(
687             self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
688         )
689
690         # We have an attention score for each of the RxL values
691
692         ar = torch.einsum(
693             "nhtd,nrtld->nhtrl",
694             Q,
695             windowed_K,
696         ) / math.sqrt(DK)
697
698         # softmax can operate only on one dimension, hence the
699         # flattening
700
701         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
702
703         ar = F.dropout(ar, self.attention_dropout, self.training)
704
705         # Compute the output for each head, flatten to concatenate
706
707         Y = torch.einsum(
708             "nhtrl,nrtld->nthd",
709             ar,
710             windowed_V,
711         ).flatten(2)
712
713         # Compute the final output
714
715         Y = blanket(Y)
716
717         self.cache_Y[:, t0:t1] = Y @ self.w_O
718
719         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
720
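# For reference, the recurrent update of rec_V computed above with the
# parallel scan amounts to the following per-time-step loop (commented-out
# sketch, assuming pscan computes Y[t] = A[t] * Y[t-1] + X[t];
# illustrative only, the same holds for rec_K):
#
# def naive_caterpillar_update(G, V, rec_V, t0, t1, L):
#     # G is (N, H, R, T'), V is (N, H, T', DV), rec_V is (N, R, T, DV),
#     # with T' = t1 - t0 and rec_V[:, :, t0 - L : t0] already initialized
#     A = 1 - G.sum(dim=1)                             # (N, R, T')
#     gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)  # (N, R, T', DV)
#     for t in range(t0, t1):
#         rec_V[:, :, t] = (
#             A[:, :, t - t0, None] * rec_V[:, :, t - L] + gated_V[:, :, t - t0]
#         )
#     return rec_V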
721
722 ##############################
723
724
725 class QKVAttention(nn.Module):
726     def __init__(
727         self,
728         dim_model,
729         dim_qk,
730         dim_v,
731         nb_heads=1,
732         causal=False,
733         attention_dropout=0.0,
734         logger=print,
735         args=None,
736     ):
737         super().__init__()
738
739         def randw(*d):
740             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
741
742         self.causal = causal
743         self.attention_dropout = attention_dropout
744         self.record_attention = False
745
746         self.w_q = randw(nb_heads, dim_qk, dim_model)
747         self.w_k = randw(nb_heads, dim_qk, dim_model)
748         self.w_v = randw(nb_heads, dim_v, dim_model)
749         self.w_o = randw(dim_v * nb_heads, dim_model)
750
751     def forward(self, bs):
752         x_q = bs.x
753
754         assert (
755             self.causal or bs.complete()
756         ), "Partial evaluation is only possible for causal models"
757
758         if bs.init_cache:
759             self.cache_k = x_q.new_zeros(
760                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
761             )
762             self.cache_v = x_q.new_zeros(
763                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
764             )
765             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
766
767         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
768
769         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
770             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
771         )
772         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
773             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
774         )
775
776         a = torch.einsum(
777             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
778         ) / math.sqrt(self.w_q.size(1))
779
780         if self.causal:
781             if bs.init_cache:
782                 self.cache_attzero = (
783                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
784                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
785                 )
786             a = a.masked_fill(
787                 self.cache_attzero[
788                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
789                 ],
790                 float("-inf"),
791             )
792
793         a = a.softmax(dim=3)
794
795         if self.record_attention:
796             self.a = a
797
798         a = F.dropout(a, self.attention_dropout, self.training)
799
800         y = torch.einsum(
801             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
802         ).flatten(2)
803
804         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
805
806         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
807
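# Commented-out sanity check (illustrative only): with causal=True and
# no dropout, generating token by token with the cache must match a
# single full pass.
#
# att = QKVAttention(dim_model=16, dim_qk=8, dim_v=8, nb_heads=2, causal=True)
# x = torch.randn(2, 10, 16)
# y_full = att(BracketedSequence(x)).x.clone()
# att(BracketedSequence(x, first=0, nb=1, init_cache=True))
# for t in range(1, 10):
#     y_step = att(BracketedSequence(x, first=t, nb=1, init_cache=False)).x
# assert (y_full - y_step).abs().max() < 1e-5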
808
809 ##############################
810
811
812 class MyGPT(nn.Module):
813     def __init__(
814         self,
815         vocabulary_size,
816         dim_model,
817         dim_keys,
818         dim_hidden,
819         nb_heads,
820         nb_blocks,
821         nb_lines=None,
822         caterpillar_height=None,
823         causal=False,
824         dropout=0.0,
825         len_max=1e5,
826         attention_layer="caterpillar",
827         logger=print,
828         args=None,
829     ):
830         super().__init__()
831
832         assert attention_layer in {
833             "mha",
834             "dumbrec",
835             "kvrec",
836             "caterpillar",
837         }, f"Unknown attention operator {attention_layer}."
838
839         if attention_layer == "caterpillar":
840             assert nb_lines % caterpillar_height == 0
841             self.caterpillar_length = nb_lines // caterpillar_height
842             self.caterpillar_height = caterpillar_height
843         else:
844             self.caterpillar_length = -1
845             self.caterpillar_height = -1
846
847         assert dim_model % nb_heads == 0
848
849         self.embedding = nn.Sequential(
850             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
851             AddPositionalEncoding(len_max),
852         )
853
854         trunk_blocks = []
855
856         def attlayer():
857             if attention_layer == "mha":
858                 return QKVAttention(
859                     dim_model=dim_model,
860                     dim_qk=dim_keys,
861                     dim_v=dim_model // nb_heads,
862                     nb_heads=nb_heads,
863                     causal=causal,
864                     attention_dropout=dropout,
865                     logger=logger,
866                     args=args,
867                 )
868             elif attention_layer == "dumbrec":
869                 return DumbRec(
870                     dim_model=dim_model,
871                     dim_qk=dim_keys,
872                     dim_v=dim_model // nb_heads,
873                     nb_heads=nb_heads,
874                     nb_lines=nb_lines,
875                     attention_dropout=dropout,
876                     logger=logger,
877                     args=args,
878                 )
879             elif attention_layer == "kvrec":
880                 return KVRec(
881                     dim_model=dim_model,
882                     dim_qk=dim_keys,
883                     dim_v=dim_model // nb_heads,
884                     nb_heads=nb_heads,
885                     nb_lines=nb_lines,
886                     attention_dropout=dropout,
887                     logger=logger,
888                     args=args,
889                 )
890             elif attention_layer == "caterpillar":
891                 return Caterpillar(
892                     dim_model=dim_model,
893                     dim_qk=dim_keys,
894                     dim_v=dim_model // nb_heads,
895                     nb_heads=nb_heads,
896                     caterpillar_length=self.caterpillar_length,
897                     caterpillar_height=self.caterpillar_height,
898                     attention_dropout=dropout,
899                     logger=logger,
900                     args=args,
901                 )
902             else:
903                 raise ValueError(f"Unknown attention type {attention_layer}.")
904
905         for b in range(nb_blocks):
906             trunk_blocks += [
907                 WithResidual(
908                     CacheWrapper(nn.LayerNorm((dim_model,))),
909                     attlayer(),
910                 ),
911                 WithResidual(
912                     CacheWrapper(
913                         nn.LayerNorm((dim_model,)),
914                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
915                         nn.ReLU(),
916                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
917                         nn.Dropout(dropout),
918                     ),
919                 ),
920             ]
921
922         self.trunk = nn.Sequential(*trunk_blocks)
923
924         self.readout = CacheWrapper(
925             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
926         )
927
928         with torch.no_grad():
929             for m in self.modules():
930                 if isinstance(m, nn.Embedding):
931                     m.weight.normal_(mean=0, std=2e-2)
932                 elif isinstance(m, nn.LayerNorm):
933                     m.bias.zero_()
934                     m.weight.fill_(1.0)
935
936         self.reset_inner_loss()
937
938     def forward(self, bs):
        # Shift the input right by one token (dropping the last one), so that
        # the prediction at position t is conditioned only on tokens at positions < t
939         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
940
941         # To make the code simpler in the Caterpillar layer, we pad
942         # here. It's unclear if/how much it hurts computationally by
943         # increasing the sequence length for the other layers
944
945         if self.caterpillar_length > 0:
946             original_nb = bs.nb
947             if bs.nb % self.caterpillar_length > 0:
948                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
949
950             bs = BracketedSequence(
951                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
952                 bs.first + self.caterpillar_length,
953                 bs.nb,
954                 bs.init_cache,
955             )
956
957         bs = self.embedding(bs)
958         bs = self.trunk(bs)
959         bs = self.readout(bs)
960
961         if self.caterpillar_length > 0:
962             bs = BracketedSequence(
963                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
964                 bs.first - self.caterpillar_length,
965                 original_nb,
966                 bs.init_cache,
967             )
968
969         return bs
970
971     # ar_mask is a tensor with 0s and 1s, of the same shape as the input, with
972     # 1s where tokens should be generated. The others are kept
973     # unchanged.
974
975     def masked_inplace_autoregression(
976         self,
977         input_src,
978         ar_mask_src,
979         forbidden_tokens=None,
980         deterministic_synthesis=False,
981     ):
982         input = input_src.to(self.readout.f.weight.device)
983         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
984         to_generate = (ar_mask.sum(0) > 0).nonzero()
985         if to_generate.min() > 0:
986             self(
987                 BracketedSequence(input, 0, to_generate.min(), True)
988             )  # Needed to initialize the model's cache
989         for s in range(to_generate.min(), to_generate.max() + 1):
990             output = self(BracketedSequence(input, s, 1, s == 0)).x
991             logits = output[:, s]
992             if forbidden_tokens is not None:
993                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
994             if deterministic_synthesis:
995                 t_next = logits.argmax(1)
996             else:
997                 dist = torch.distributions.categorical.Categorical(logits=logits)
998                 t_next = dist.sample()
999             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
1000
1001         input_src.copy_(input)
1002
1003     def reset_inner_loss(self):
1004         for m in self.modules():
1005             if m is not self and hasattr(m, "reset_inner_loss"):
1006                 m.reset_inner_loss()
1007
1008     def get_inner_loss(self):
1009         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1010         for m in self.modules():
1011             if m is not self and hasattr(m, "get_inner_loss"):
1012                 l += m.get_inner_loss()
1013         return l
1014
1015     def record_attention(self, v=True):
1016         for m in self.modules():
1017             if isinstance(m, QKVAttention):
1018                 m.record_attention = v
1019
1020     def retrieve_attention(self):
1021         a = []
1022         for m in self.modules():
1023             if isinstance(m, QKVAttention):
1024                 a.append(m.a)
1025         return a
1026
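# Usage sketch for masked_inplace_autoregression (commented out,
# illustrative only; `model`, `vocabulary_size` and `device` are assumed
# to come from a setup such as the one in the __main__ section below):
#
# input = torch.randint(vocabulary_size, (4, 32), device=device)
# ar_mask = torch.zeros_like(input)
# ar_mask[:, 16:] = 1  # keep the first 16 tokens, generate the remaining ones
# model.eval()
# model.masked_inplace_autoregression(input, ar_mask)
# # input now holds sampled tokens wherever ar_mask is 1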
1027
1028 ######################################################################
1029
1030 if __name__ == "__main__":
1031     import argparse
1032
1033     import numpy as np
1034     import matplotlib.pyplot as plt
1035     import matplotlib.collections as mc
1036
1037     args = argparse.Namespace(
1038         gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
1039     )
1040
1041     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1042
1043     dim_model, dim_keys, nb_heads = 512, 64, 1
1044     dropout = 0.1
1045
1046     caterpillar = Caterpillar(
1047         dim_model=dim_model,
1048         dim_qk=dim_keys,
1049         dim_v=dim_model // nb_heads,
1050         nb_heads=nb_heads,
1051         caterpillar_length=16,
1052         caterpillar_height=32,
1053         attention_dropout=dropout,
1054         args=args,
1055     ).to(device)
1056
1057     qkv = QKVAttention(
1058         dim_model=dim_model,
1059         dim_qk=dim_keys,
1060         dim_v=dim_model // nb_heads,
1061         nb_heads=nb_heads,
1062         causal=True,
1063         attention_dropout=dropout,
1064         args=args,
1065     ).to(device)
1066
1067     linear = CacheWrapper(nn.Linear(512, 512)).to(device)
1068
1069     x = torch.randn(1, 256, dim_model)
1070
1071     x = x.to(device)
1072     x.requires_grad_()
1073
1074     ######################################################################
1075
1076     fig = plt.figure()
1077     fig.set_figheight(6)
1078     fig.set_figwidth(8)
1079
1080     ax = fig.add_subplot(1, 1, 1)
1081
1082     # ax.set_xlim(-1.5, 1.5)
1083     # ax.set_ylim(-1.5, 1.5)
1084     # ax.set(aspect=1)
1085     # ax.spines.right.set_visible(False)
1086     # ax.spines.top.set_visible(False)
1087
1088     # dt = 0.01
1089     # t = np.arange(dt, 20.0, dt)
1090     # ax.semilogx(t, np.exp(-t / 5.0))
1091     # ax.grid()
1092     ax.set_yscale("log")
1093
1094     ######################################################################
1095
1096     for label, model, thickness in [
1097         ("nn.Linear", linear, 0.2),
1098         ("mygpt.QKVAttention", qkv, 1),
1099         ("mygpt.Caterpillar", caterpillar, 2),
1100     ]:
1101         y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
1102
1103         for n, p in [("input", x)] + list(model.named_parameters()):
1104             print(f"Processing {label}.{n}")
1105             data = []
1106             for t in range(y.size(1)):
1107                 sg = 0
1108                 for d in torch.randperm(y.size(2))[:8]:
1109                     sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
1110                 assert not sg.isinf().any()
1111                 assert not sg.isnan().any()
1112                 data.append([t, sg.sum().item()])
1113
1114             data = torch.tensor(data)
1115             # cx, cy = data[:, 0], data[:, 1]
1116             cy = data[:, 1].sort().values
1117             cx = torch.linspace(0, 1, cy.size(0))
1118             ax.plot(
1119                 cx, cy, label=label + "." + n, linewidth=thickness
1120             )  # , color='gray', label='Input')
1121
1122     # ax.legend(frameon=False, loc="top right")
1123
1124     # Put a legend to the right of the current axis
1125     box = ax.get_position()
1126     ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
1127     ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1128
1129     filename = "plot.pdf"
1130     print(f"saving {filename}")
1131     fig.savefig(filename, bbox_inches="tight")
1132
1133     # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
1134     # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
1135     # plt.show()
1136
1137     exit(0)
1138
1139     ######################################################################
1140
1141     m = Caterpillar(
1142         dim_model=4,
1143         dim_qk=3,
1144         dim_v=7,
1145         nb_heads=1,
1146         caterpillar_length=7,
1147         caterpillar_height=3,
1148         attention_dropout=0.0,
        args=args,
1149     )
1150
1151     # m.reset_inner_loss()  # Caterpillar does not define reset_inner_loss (it is commented out above)
1152     x = torch.randn(1, 21 + 2 * 7, 4)
1153     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1154     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1155     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1156     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1157     print((y1 - y2).abs().max())
1158     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1159     exit(0)
1160
1161     vocabulary_size = 128
1162     x = torch.randint(vocabulary_size, (6, 1024))
1163
1164     model = MyGPT(
1165         vocabulary_size=vocabulary_size,
1166         dim_model=512,
1167         dim_keys=64,
1168         dim_hidden=2048,
1169         nb_heads=8,
1170         nb_lines=128,
        caterpillar_height=16,  # required by the default caterpillar attention (arbitrary divisor of nb_lines)
1171         nb_blocks=12,
1172         dropout=0.1,
1173         causal=True,
        args=args,
1174     )
1175
1176     x = x.to(device)
1177     model.to(device)
1178
1179     import time, sys
1180
1181     # import torchvision.models as models
1182     # from torch.profiler import profile, record_function, ProfilerActivity
1183
1184     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1185     # with record_function("model_inference"):
1186
1187     model.eval()
1188     for i in range(3):
1189         start_time = time.perf_counter()
1190         for k in range(10):
1191             model(BracketedSequence(x))
1192         duration = time.perf_counter() - start_time
1193         print(duration)
1194         sys.stdout.flush()
1195
1196     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1197     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1198
1199     # print("##############################################################")
1200     # y2 = torch.randn_like(y1)
1201     # for s in range(x.size(1)):
1202     # z = model(BracketedSequence(x, s, 1))
1203     # y2[:, s : s + 1] = z.slice()
1204
1205     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1206
1207 ######################################################################