Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers to replace the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor together with a bracket of
29 # time steps to compute: nb steps starting at index first.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor together with the bracket of time steps to compute.

    The bracket is [first, first + nb) along axis 1. Either all three of
    first, nb and init_cache are given, or none of them; in the latter case
    the bracket spans the whole time axis and init_cache is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        provided = [arg is not None for arg in (first, nb, init_cache)]
        assert all(provided) or not any(provided)

        # Defaults describe a bracket covering the whole sequence, with a
        # fresh cache.
        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the part of x covered by the bracket."""
        start = self.first
        stop = self.first + self.nb
        return self.x[:, start:stop]

    def complete(self):
        """Return whether the bracket covers the whole time axis."""
        starts_at_zero = self.first == 0
        return starts_at_zero and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket only and cache its output.

    The wrapped module(s) (a single module, or several chained with
    nn.Sequential) receive bs.slice(), i.e. only the time steps of the
    bracket, and their output is written into a cache spanning the whole
    time axis. The returned BracketedSequence wraps that cache with the same
    bracket as the input.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled rather than uninitialized (legacy Tensor.new) so
            # that time steps outside the bracket hold deterministic values
            # instead of arbitrary memory contents (possibly NaN/Inf).
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            # Subsequent brackets must belong to the same batch/sequence
            # length as the one that initialized the cache.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a BracketedSequence module: x + f(x).

    Several modules given to the constructor are chained with nn.Sequential.
    The sum is taken over the full tensors, and the bracket of the input is
    kept.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs).x
        return BracketedSequence(skip + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to a BracketedSequence.

    len_max is the base L of the wavelengths in the formula below. The
    encoding table is (re)computed when the cache is initialized, and only
    the time steps of the current bracket are written into the output cache.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # Odd channels use cos(a) = sin(a + pi/2), with the exponent of
            # their even neighbour (j - k = 2i).
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled rather than uninitialized so time steps outside the
            # processed brackets hold deterministic values.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan along axis `dim` of X.

    Shapes (see the comment above): X is /.../xTxD with T at `dim`, A has
    the shape of X without its last axis, and Y_init (optional) has the
    shape of X without the `dim` axis. All axes before `dim` are flattened
    into one batch axis; the axes between `dim` and the last one are kept
    as is. When Y_init is None, a zero initial state is used.

    Assumes pscan.pscan computes Y[t] = A[t] * Y[t-1] + X[t] along axis 1
    of its inputs -- NOTE(review): confirm against the pscan module.

    Returns a tensor with the shape of X.
    """
    s = X.size()
    a, T = s[:dim].numel(), s[dim]
    inner = s[dim + 1 : -1]  # axes kept between the scanned one and D

    A = A.reshape(a, T, *inner)
    X = X.reshape(a, T, *inner, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *inner, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *inner, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
148
149
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan along the second-to-last axis of X.

    X is /.../xTxD, A is /.../xT and Y_init (optional) is /.../xD. All
    leading axes are flattened into a single batch axis; a zero initial
    state is used when Y_init is None. Returns a tensor shaped like X.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)

    if Y_init is not None:
        flat_Y0 = Y_init.reshape(-1, D)
    else:
        flat_Y0 = flat_X.new_zeros(flat_X.size(0), D)

    result = pscan.pscan(flat_A, flat_X, flat_Y0)
    return result.reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Normalized running sum over the second-to-last axis of X.

    X is /.../xTxD and Y_init (optional, /.../xD) is the starting state,
    zero when None. At each step the current X vector is added to the
    state, which is then projected back into the unit ball (divided by its
    norm when that norm exceeds 1). Returns all the successive states, with
    the shape of X.
    """
    shape = X.size()
    flat = X.reshape(-1, shape[-2], shape[-1])  # N x T x D

    state = Y_init.reshape(-1, shape[-1]) if Y_init is not None else 0
    states = []

    for x_t in flat.unbind(dim=1):
        state = state + x_t
        state = state / state.norm(dim=-1, keepdim=True).clamp(min=1)
        states.append(state)

    return torch.stack(states, dim=1).reshape(shape)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent replacement for MHA with a fixed set of learned line keys.

    The recurrent state is made of nb_lines value vectors. At each time step
    every head computes write weights over the lines (softmax of its write
    query against the line keys, whose assignment to lines rotates with
    time), and the lines are updated with a parallel scan using
        A = 1 - sum_h aw[h]   and   V = sum_h aw[h] * v_h
    (assuming pscan computes Y[t] = A[t] * Y[t-1] + X[t] -- confirm in the
    pscan module). The readout attends over the lines with a separate read
    query against the (non-rotated) line keys.

    len_max is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean per-line write attention.

        The statistics are accumulated by forward() in training mode since
        the last reset_inner_loss(). Returns a zero scalar when nothing has
        been accumulated, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # The key used by line l at time t is k_star[(l + t) % nb_lines]: the
        # assignment of keys to lines rotates by one at every time step.
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # `self.training` is the mode flag; `self.train` is a method and
        # would always be truthy.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue the recurrence from the state right before the bracket
        if t0 == 0:
            V0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
class KVRec(nn.Module):
    """Recurrent replacement for MHA storing both keys and values in lines.

    Like DumbRec, each head writes into nb_lines recurrent lines with
    weights given by a softmax of its write query against rotating line
    keys, and the lines are updated with a parallel scan (assuming pscan
    computes Y[t] = A[t] * Y[t-1] + X[t] -- confirm in the pscan module).
    Here both the values and the keys are stored recurrently, and the
    readout attends over the stored keys of each line.

    len_max is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the mean per-line write attention.

        The statistics are accumulated by forward() in training mode since
        the last reset_inner_loss(). Returns a zero scalar when nothing has
        been accumulated, instead of dividing by zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # The key used by line l at time t is k_star[(l + t) % nb_lines]: the
        # assignment of keys to lines rotates by one at every time step.
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # `self.training` is the mode flag; `self.train` is a method and
        # would always be truthy.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue the recurrences from the states right before the bracket
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
446 # Returns a tensor with an additional index at rank win_dim, that moves
447 # along the same dimension as dim, over the domain {0...win_size-1},
448 # while dim is restricted to a domain reduced by win_size-1 values.
449
450
def moving_window(x, dim, win_dim, win_size):
    """Return a sliding-window view of x along `dim`.

    The result has one more axis than x, at position win_dim, of size
    win_size; axis `dim` shrinks to size[dim] - win_size + 1, and
        result[..., i (at dim), ..., k (at win_dim), ...] = x[..., i + k, ...]
    It is a view sharing storage with x (no copy).

    `dim` may be negative (Tensor.unfold handles it; the previous manual
    size/stride slicing was wrong for dim=-1). A negative win_dim keeps the
    original convention: it counts from the end of the *input* shape, i.e.
    the window axis is inserted before that input axis.
    """
    if win_dim < 0:
        win_dim += x.dim()
    # unfold appends the window axis last (stride = stride[dim]); move it
    # to the requested position.
    return x.unfold(dim, win_size, 1).movedim(-1, win_dim)
458
459
460 ##############################
461
462
class Caterpillar(nn.Module):
    """Recurrent replacement for MHA with a CH x CL ("caterpillar") memory.

    The recurrent key/value state at time t is obtained by updating the
    state at time t - CL with the gated keys/values computed at t (the
    parallel scan runs with a period of CL). The readout at time t attends
    over the CH x CL key/value pairs of the states at times t - CL + 1 ... t.

    caterpillar_height must be at least 2: the gate bias is initialized to
    -log(CH - 1), and math.log(0) raises ValueError. len_max is accepted for
    interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Training-time regularizations, disabled by default
        self.proba_flashback = 0.0
        self.proba_gate_dropout = 0.0

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        # sigmoid(-log(CH - 1)) = 1 / CH: initially each gate is ~1/CH
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No inner loss for this layer
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        if self.training and self.proba_gate_dropout > 0.0:
            # NOTE(review): gate dropout is not implemented, only the
            # warning is emitted.
            warnings.warn("gate droupout", RuntimeWarning)

        # That was a bad idea
        # G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        # Clip the gating: rescale so that the total gate over the heads is
        # at most 1, which keeps A = 1 - G.sum(1) in [0, 1].
        warnings.warn("gating clipping", RuntimeWarning)
        G = G / G.sum(1, keepdim=True).clamp(min=1)

        A = 1 - G.sum(1)  # N x CH x (t1 - t0)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        if self.training and self.proba_flashback > 0.0:
            warnings.warn("flash back", RuntimeWarning)
            # This piece of code makes the assumption that there is
            # nothing informative before t0, otherwise we'd have to
            # implement a cache for V and K too. This should not be
            # too much of a problem since this is used only during
            # train, where full sequence are available

            n = torch.arange(N, device=X.device)[:, None, None, None]
            t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
            dv = torch.arange(DV, device=X.device)[None, None, None, :]
            dk = torch.arange(DK, device=X.device)[None, None, None, :]

            # Random backward shift, a multiple of CL, bounded by t - t0 so
            # that 0 <= src_time <= t - t0. Scaling by t (instead of t - t0)
            # could make src_time negative, and negative indices wrap around
            # to the end of V/K, i.e. to future time steps.
            u = (
                torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t - t0).long()
                // CL
            ) * CL

            src_time = t - u - t0
            src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)

            # A single draw wide enough for both V and K, so that it also
            # works when DK != DV (a DV-wide mask cannot broadcast against K).
            mask = (
                torch.rand(N, CH, t1 - t0, max(DV, DK), device=X.device)
                <= self.proba_flashback
            ).long()
            mask_V = mask[..., :DV]
            mask_K = mask[..., :DK]

            self.rec_V[:, :, t0:t1] = (
                mask_V * V[n, src_head, src_time, dv]
                + (1 - mask_V) * self.rec_V[:, :, t0:t1]
            )

            self.rec_K[:, :, t0:t1] = (
                mask_K * K[n, src_head, src_time, dk]
                + (1 - mask_K) * self.rec_K[:, :, t0:t1]
            )

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
676
677
678 ##############################
679
680
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product attention on BracketedSequences.

    Keys and values of every processed time step are cached, so a bracket
    [first, first + nb) only projects its own tokens and attends over the
    keys/values of steps 0 ... first + nb - 1. With causal=True, step t
    cannot attend to steps s > t; otherwise only complete sequences may be
    processed. When record_attention is True, the attention weights (before
    dropout) are kept in self.a.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*shape):
            # Gaussian weights scaled by 1/sqrt of the last dimension
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            batch, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(batch, self.w_k.size(0), length, self.w_k.size(1))
            self.cache_v = x.new_zeros(batch, self.w_v.size(0), length, self.w_v.size(1))
            self.cache_y = x.new_zeros(batch, length, self.w_o.size(1))

        x_bracket = x[:, t0:t1]
        q = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_bracket, self.w_v)

        # Scores of the bracket's queries against every key seen so far
        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                # True where the key position s is after the query position t
                pos = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = pos[None, None, :, None] < pos[None, None, None, :]
            forbidden = self.cache_attzero[:, :, t0:t1, :t1]
            a = a.masked_fill(forbidden, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        heads = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1])
        self.cache_y[:, t0:t1] = heads.flatten(2) @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
761
762
763 ##############################
764
765
766 class MyGPT(nn.Module):
767     def __init__(
768         self,
769         vocabulary_size,
770         dim_model,
771         dim_keys,
772         dim_hidden,
773         nb_heads,
774         nb_blocks,
775         nb_lines=None,
776         caterpillar_height=None,
777         dim_rec_v=-1,
778         causal=False,
779         dropout=0.0,
780         len_max=1e5,
781         attention_layer="kvrec",
782     ):
783         super().__init__()
784
785         assert attention_layer in {
786             "mha",
787             "dumbrec",
788             "kvrec",
789             "caterpillar",
790         }, f"Unknown attention operator {attention_layer}."
791
792         if attention_layer == "caterpillar":
793             assert nb_lines % caterpillar_height == 0
794             self.caterpillar_length = nb_lines // caterpillar_height
795             self.caterpillar_height = caterpillar_height
796         else:
797             self.caterpillar_length = -1
798             self.caterpillar_height = -1
799
800         assert dim_model % nb_heads == 0
801
802         self.embedding = nn.Sequential(
803             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
804             AddPositionalEncoding(len_max),
805         )
806
807         trunk_blocks = []
808
809         def attlayer():
810             if attention_layer == "mha":
811                 return QKVAttention(
812                     dim_model=dim_model,
813                     dim_qk=dim_keys,
814                     dim_v=dim_model // nb_heads,
815                     nb_heads=nb_heads,
816                     causal=causal,
817                     attention_dropout=dropout,
818                 )
819             elif attention_layer == "dumbrec":
820                 return DumbRec(
821                     dim_model=dim_model,
822                     dim_qk=dim_keys,
823                     dim_v=dim_rec_v,
824                     nb_heads=nb_heads,
825                     nb_lines=nb_lines,
826                     attention_dropout=dropout,
827                 )
828             elif attention_layer == "kvrec":
829                 return KVRec(
830                     dim_model=dim_model,
831                     dim_qk=dim_keys,
832                     dim_v=dim_rec_v,
833                     nb_heads=nb_heads,
834                     nb_lines=nb_lines,
835                     attention_dropout=dropout,
836                 )
837             elif attention_layer == "caterpillar":
838                 return Caterpillar(
839                     dim_model=dim_model,
840                     dim_qk=dim_keys,
841                     dim_v=dim_rec_v,
842                     nb_heads=nb_heads,
843                     caterpillar_length=self.caterpillar_length,
844                     caterpillar_height=self.caterpillar_height,
845                     attention_dropout=dropout,
846                 )
847             else:
848                 raise ValueError(f"Unknown attention type {attention_layer}.")
849
850         for b in range(nb_blocks):
851             trunk_blocks += [
852                 WithResidual(
853                     CacheWrapper(nn.LayerNorm((dim_model,))),
854                     attlayer(),
855                 ),
856                 WithResidual(
857                     CacheWrapper(
858                         nn.LayerNorm((dim_model,)),
859                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
860                         nn.ReLU(),
861                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
862                         nn.Dropout(dropout),
863                     ),
864                 ),
865             ]
866
867         self.trunk = nn.Sequential(*trunk_blocks)
868
869         self.readout = CacheWrapper(
870             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
871         )
872
873         with torch.no_grad():
874             for m in self.modules():
875                 if isinstance(m, nn.Embedding):
876                     m.weight.normal_(mean=0, std=2e-2)
877                 elif isinstance(m, nn.LayerNorm):
878                     m.bias.zero_()
879                     m.weight.fill_(1.0)
880
881         self.reset_inner_loss()
882
883     def forward(self, bs):
884         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
885
886         # To make the code simpler in the Caterpillar layer, we pad
887         # here. It's unclear if/how much it hurts computationaly by
888         # increasing the sequence length for the other layers
889
890         if self.caterpillar_length > 0:
891             original_nb = bs.nb
892             if bs.nb % self.caterpillar_length > 0:
893                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
894
895             bs = BracketedSequence(
896                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
897                 bs.first + self.caterpillar_length,
898                 bs.nb,
899                 bs.init_cache,
900             )
901
902         bs = self.embedding(bs)
903         bs = self.trunk(bs)
904         bs = self.readout(bs)
905
906         if self.caterpillar_length > 0:
907             bs = BracketedSequence(
908                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
909                 bs.first - self.caterpillar_length,
910                 original_nb,
911                 bs.init_cache,
912             )
913
914         return bs
915
916     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
917     # 1s where tokens should be generated. The others are kept
918     # unchanged.
919
920     def masked_inplace_autoregression(
921         self,
922         input_src,
923         ar_mask_src,
924         forbidden_tokens=None,
925         deterministic_synthesis=False,
926     ):
927         input = input_src.to(self.readout.f.weight.device)
928         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
929         to_generate = (ar_mask.sum(0) > 0).nonzero()
930         if to_generate.min() > 0:
931             self(
932                 BracketedSequence(input, 0, to_generate.min(), True)
933             )  # Needed to initialize the model's cache
934         for s in range(to_generate.min(), to_generate.max() + 1):
935             output = self(BracketedSequence(input, s, 1, s == 0)).x
936             logits = output[:, s]
937             if forbidden_tokens is not None:
938                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
939             if deterministic_synthesis:
940                 t_next = logits.argmax(1)
941             else:
942                 dist = torch.distributions.categorical.Categorical(logits=logits)
943                 t_next = dist.sample()
944             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
945
946         input_src.copy_(input)
947
948     def reset_inner_loss(self):
949         for m in self.modules():
950             if m is not self and hasattr(m, "reset_inner_loss"):
951                 m.reset_inner_loss()
952
953     def get_inner_loss(self):
954         l = torch.tensor([0.0], device=self.readout.f.weight.device)
955         for m in self.modules():
956             if m is not self and hasattr(m, "get_inner_loss"):
957                 l += m.get_inner_loss()
958         return l
959
960     def record_attention(self, v=True):
961         for m in self.modules():
962             if isinstance(m, QKVAttention):
963                 m.record_attention = v
964
965     def retrieve_attention(self):
966         a = []
967         for m in self.modules():
968             if isinstance(m, QKVAttention):
969                 a.append(m.a)
970         return a
971
972
973 ######################################################################
974
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Cache-consistency check of the Caterpillar layer. The same input is
    # processed in one bracket (twice), then as two successive brackets where
    # the second reuses the cache (init_cache=False). Each printed value is
    # the max absolute difference between two outputs that should agree.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    # 21 time steps of interest, plus 7 extra positions on each side
    # (matching caterpillar_length=7 and first=7 below).
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The timing benchmark below is currently disabled; remove this exit to
    # run it. sys.exit is used because the builtin exit() is installed by the
    # site module for interactive use and is not guaranteed to exist.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################
1045
1046 ######################################################################