Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.
12
# This implementation is also equipped with recurrent (RNN-like) layers
# that can replace the multi-head attention (MHA).
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with a first time
# step and a number nb of time steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor together with the time bracket to process.

    The bracket is ``[first, first + nb)`` along axis 1 of ``x``, and
    ``init_cache`` tells the receiving modules to (re)build their caches.
    The three bracket arguments must be given all together or not at all.
    When they are omitted, the bracket spans the whole sequence and the
    caches are initialized.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x

        bracket = (first, nb, init_cache)
        assert all(v is None for v in bracket) or all(
            v is not None for v in bracket
        )

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the part of ``x`` inside the bracket, along axis 1."""
        start = self.first
        return self.x[:, start : start + self.nb]

    def complete(self):
        """Tell whether the bracket covers the whole time axis of ``x``."""
        starts_at_zero = self.first == 0
        return starts_at_zero and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a time-wise module to the bracket of a BracketedSequence.

    The wrapped module (or the nn.Sequential of several) is applied to the
    bracket slice only. Its output is written into a persistent buffer that
    spans the whole time axis. The buffer is reallocated (uninitialized)
    when ``bs.init_cache`` is True, and reused for later brackets.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        lo, hi = bs.first, bs.first + bs.nb

        if not bs.init_cache:
            # Reusing the buffer: batch/time layout must match, and the
            # bracket must fit in it
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert hi <= self.cache_y.size(1)

        out = self.f(bs.slice())

        if bs.init_cache:
            # Same trailing dims as the output, time axis as long as bs.x
            self.cache_y = out.new(*((out.size(0), bs.x.size(1)) + out.size()[2:]))

        self.cache_y[:, lo:hi] = out

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Add a residual connection around a BracketedSequence module.

    The wrapped module (or nn.Sequential of several) maps a
    BracketedSequence to another one. Its full ``x`` tensor is added to
    the input's ``x``, and the input's bracket is kept.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs).x
        return BracketedSequence(skip + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add the fixed sinusoidal positional encoding to a BracketedSequence.

    Following [Vaswani et al. 2017], with L = ``len_max`` and D the
    embedding dimension:

        PE_{t,2i}   = sin(t / L^{2i/D})
        PE_{t,2i+1} = cos(t / L^{2i/D})

    The cosine terms are computed as sin(. + pi/2). The encoding table and
    an uninitialized output buffer are (re)built when ``bs.init_cache`` is
    True. Each call fills only the bracket ``[bs.first, bs.first + bs.nb)``
    of the buffer.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        x = bs.x

        if bs.init_cache:
            opts = dict(dtype=x.dtype, device=x.device)
            t = torch.arange(x.size(1), **opts).unsqueeze(1)
            j = torch.arange(x.size(2), **opts).unsqueeze(0)
            k = j % 2  # 0 on even channels (sin), 1 on odd ones (cos)
            exponent = (j - k) / x.size(2)
            self.pe = torch.sin(t / (self.len_max**exponent) + math.pi / 2 * k)
            self.cache_y = x.new(x.size())

        lo, hi = bs.first, bs.first + bs.nb
        self.cache_y[:, lo:hi] = bs.slice() + self.pe[lo:hi]

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` along axis ``dim`` of X.

    X has shape (*lead, T, *mid, D), with T at position ``dim``. A has
    shape (*lead, T, *mid). Y_init has shape (*lead, *mid, D), or is None,
    in which case zeros are used. All the ``lead`` axes are flattened into
    a single batch axis for the scan, and the result is reshaped back to
    the shape of X.

    NOTE(review): pscan.pscan is defined elsewhere; presumably it computes
    the linear recurrence Y[t] = A[t] * Y[t-1] + X[t] from Y_init — confirm.
    """
    shape = X.size()
    lead = shape[:dim].numel()
    T = shape[dim]
    mid = shape[dim + 1 : -1]

    A_flat = A.reshape(lead, T, *mid)
    X_flat = X.reshape(lead, T, *mid, -1)

    if Y_init is not None:
        Y0 = Y_init.reshape(lead, *mid, -1)
    else:
        Y0 = X_flat.new_zeros(lead, *mid, X_flat.size(-1))

    return pscan.pscan(A_flat, X_flat, Y0).reshape(shape)
148
149
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` along axis -2 of X, for any number of leading axes.

    X has shape (..., T, D), A has shape (..., T), and Y_init is either
    None (zeros are used) or of shape (..., D). The leading axes are
    flattened into one batch axis, and the result has the shape of X.

    NOTE(review): pscan.pscan is defined elsewhere; presumably it computes
    Y[t] = A[t] * Y[t-1] + X[t] from Y_init — confirm.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    A_flat = A.reshape(-1, T)
    X_flat = X.reshape(-1, T, D)

    if Y_init is not None:
        Y0 = Y_init.reshape(-1, D)
    else:
        Y0 = X_flat.new_zeros(X_flat.size(0), D)

    return pscan.pscan(A_flat, X_flat, Y0).reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Accumulate X along axis -2, projecting the running sum into the unit ball.

    X has shape (..., T, D). Starting from Y_init (shape (..., D)), or from
    0 if it is None, each step adds X[..., t, :] to the running vector.
    The result is then divided by max(1, its L2 norm). The T successive
    running vectors are returned with the shape of X.
    """
    shape = X.size()
    steps = X.reshape(-1, shape[-2], shape[-1])

    acc = 0 if Y_init is None else Y_init.reshape(-1, shape[-1])
    history = []

    for x_t in steps.unbind(dim=1):
        acc = acc + x_t
        acc = acc / acc.norm(dim=-1, keepdim=True).clamp(min=1)
        history.append(acc)

    return torch.stack(history, dim=1).reshape(shape)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention with learned fixed keys.

    The layer keeps ``nb_lines`` recurrent value lines per sample.

    Write side: for each time step, write queries (one per head) attend
    over the learned keys ``k_star``. The assignment of keys to lines is
    rotated with time (line l uses key (l + t) % nb_lines). The attention
    weights scatter the current values into the lines, and 1 minus their
    sum over heads is the multiplicative factor passed to the parallel
    scan.

    Read side: read queries attend over ``k_star`` to mix the lines.

    ``len_max`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the L2 norm squared of the mean write attention per line.

        Returns a 0-dim zero tensor if no statistics were accumulated since
        the last reset (e.g. only eval-mode passes), instead of dividing
        0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state: N x nb_lines x T x dim_v
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: line l at time t uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training (a bool), not self.train (a method, always truthy):
        # statistics must only be accumulated in training mode
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue the recurrence from the state right before the bracket
        if t0 == 0:
            V0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        # NOTE(review): the readout uses the un-rotated self.k_star while the
        # write side uses the rotated barrel keys; confirm this is intended.
        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention with recurrent keys.

    The layer keeps ``nb_lines`` recurrent (key, value) lines per sample.

    Write side: for each time step, write queries attend over learned
    keys ``k_star``. The assignment of keys to lines is rotated with time
    (line l uses key (l + t) % nb_lines). The attention weights scatter
    the current keys and values into the lines, and 1 minus their sum
    over heads is the multiplicative factor passed to the parallel scan.

    Read side: read queries attend over the recurrent keys to mix the
    recurrent values.

    ``len_max`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the L2 norm squared of the mean write attention per line.

        Returns a 0-dim zero tensor if no statistics were accumulated since
        the last reset (e.g. only eval-mode passes), instead of dividing
        0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.w_qw.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # Alternative tried: penalize lines used less than 0.5 / nb_lines
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent states: N x nb_lines x T x dim
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys: line l at time t uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # self.training (a bool), not self.train (a method, always truthy):
        # statistics must only be accumulated in training mode
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue the recurrences from the state right before the bracket
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
# Returns a tensor with an additional index at rank win_dim, which moves
# along the same dimension as dim, over the domain {0...win_size-1}, and
# where dim is restricted to a domain reduced by win_size-1 values.
449
450
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x with a sliding window over axis ``dim``.

    The result has an extra axis, inserted at position ``win_dim`` (same
    insertion rule as a tuple slice ``s[:win_dim] + (w,) + s[win_dim:]``).
    That axis ranges over {0, ..., win_size-1} and moves along ``dim``,
    while ``dim`` itself is shortened by win_size-1, so that

        y[..., t, ..., l, ...] == x[..., t + l, ...]

    No data is copied: the view shares x's storage and storage offset.
    ``dim`` may be negative.
    """
    # Normalize dim: with dim=-1 the former size[dim + 1 :] was size[0:],
    # which produced a size of the wrong rank.
    dim = dim % x.dim()

    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
458
459
460 ##############################
461
462
class Caterpillar(nn.Module):
    """Recurrent attention replacement based on a CH x CL "caterpillar" state.

    The recurrent state is made of CH = ``caterpillar_height`` rows of
    (key, value) pairs.

    Write side: the state at time t is obtained from the state at time
    t - CL (CL = ``caterpillar_length``). Per-head, per-row sigmoid gates G
    decide how much of the current key/value is stored in each row. The
    gates are clipped so that their sum over heads is at most 1, and
    1 - sum(G) is the multiplicative factor passed to the parallel scan.

    Read side: at time t the queries attend jointly over the CH x CL
    keys of the states at times t-CL+1 ... t.

    Constraints: caterpillar_height must be >= 2, since the gate bias is
    initialized to -log(CH - 1) (so that the initial gates are 1/CH).
    ``forward`` requires bs.first >= CL and bs.nb to be a multiple of CL.
    ``len_max`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.proba_flashback = 0.0
        self.proba_gate_dropout = 0.0

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        # sigmoid(-log(CH - 1)) = 1 / CH
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        """Reset the (unused) attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """This layer has no inner regularization: always return [0]."""
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert t0 >= CL and (t1 - t0) % CL == 0, (
            f"bs.first ({t0}) should be at least caterpillar_length ({CL}), "
            f"and bs.nb ({t1 - t0}) should be a multiple of caterpillar_length"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the CH pairs of the current
        # stack. The CH gating values are independent, which means
        # that the current K/V could be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        if self.training and self.proba_gate_dropout > 0.0:
            # NOTE: gate dropout is not implemented, this branch only warns
            warnings.warn("gate droupout", RuntimeWarning)

        # That was a bad idea
        # G = F.dropout(G, self.attention_dropout, self.training)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        # Clip the gating so that the sum over heads is at most 1
        warnings.warn("gating clipping", RuntimeWarning)
        G = G / G.sum(1, keepdim=True).clamp(min=1)

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: Since the stack at time t is computed
        # by updating that at time t-L, the parallel scan operates
        # with a period of L. To do so we split the time indexing in
        # two axes, the second of size CL, and run the parallel scan
        # using the other as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        if self.training and self.proba_flashback > 0.0:
            warnings.warn("flash back", RuntimeWarning)
            # Replace random entries of the recurrent state by the
            # key/value of a random head at an earlier time step with the
            # same phase modulo CL. V and K are only computed for t0:t1,
            # so the source time is drawn within [t0, t]: a negative index
            # would silently wrap around to the end of V/K, i.e. to future
            # time steps.

            n = torch.arange(N, device=X.device)[:, None, None, None]
            t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
            dv = torch.arange(DV, device=X.device)[None, None, None, :]
            dk = torch.arange(DK, device=X.device)[None, None, None, :]

            # Backward offset u: a multiple of CL in [0, t - t0]
            u = (
                torch.rand(N, CH, t1 - t0, 1, device=X.device)
                .mul(t - t0 + 1)
                .long()
                // CL
            ) * CL

            src_time = t - u - t0  # in [0, t - t0]
            src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)

            # Separate masks since values and keys may differ in dimension
            mask_V = (
                torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
            ).long()
            mask_K = (
                torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
            ).long()

            self.rec_V[:, :, t0:t1] = (
                mask_V * V[n, src_head, src_time, dv]
                + (1 - mask_V) * self.rec_V[:, :, t0:t1]
            )

            self.rec_K[:, :, t0:t1] = (
                mask_K * K[n, src_head, src_time, dk]
                + (1 - mask_K) * self.rec_K[:, :, t0:t1]
            )

        ######################################################################
        # Compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
676
677
678 ##############################
679
680
class QKVAttention(nn.Module):
    """Multi-head query/key/value attention on BracketedSequences.

    Keys and values of the processed bracket are stored in caches that
    span the whole time axis. This allows auto-regressive evaluation one
    bracket at a time. Partial (non-complete) brackets are only allowed
    when ``causal`` is True. When ``record_attention`` is set, the
    post-softmax (pre-dropout) attention is kept in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Creation order fixed: it determines the random initialization
        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        lo, hi = bs.first, bs.first + bs.nb

        if bs.init_cache:
            N, T = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(N, self.w_k.size(0), T, self.w_k.size(1))
            self.cache_v = x.new_zeros(N, self.w_v.size(0), T, self.w_v.size(1))
            self.cache_y = x.new_zeros(N, T, self.w_o.size(1))

        x_new = x[:, lo:hi]
        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, lo:hi] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Queries of the bracket against every key up to its end
        scores = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, :hi]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                # True where the key time s is strictly after the query time t
                pos = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = pos[None, None, :, None] < pos[None, None, None, :]
            scores = scores.masked_fill(
                self.cache_attzero[:, :, lo:hi, :hi], float("-inf")
            )

        weights = scores.softmax(dim=3)

        if self.record_attention:
            self.a = weights

        weights = F.dropout(weights, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", weights, self.cache_v[:, :, :hi]
        ).flatten(2)

        self.cache_y[:, lo:hi] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
761
762
763 ##############################
764
765
766 class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        """Build the embedding, the trunk of residual blocks and the readout.

        Args:
            vocabulary_size: number of token values (embedding and readout size).
            dim_model: width of the residual stream; must be divisible by
                nb_heads, each head getting a value dimension of
                dim_model // nb_heads.
            dim_keys: query/key dimension passed to the attention layers.
            dim_hidden: hidden width of the per-block MLP.
            nb_heads: number of attention heads.
            nb_blocks: number of (attention, MLP) residual block pairs.
            nb_lines: number of recurrent lines for "dumbrec" / "kvrec". For
                "caterpillar" it is the total CH x CL size and must be a
                multiple of caterpillar_height.
            caterpillar_height: CH, used only by "caterpillar".
            causal: causal masking flag, used only by "mha".
            dropout: dropout rate of the embedding, of the MLPs, and of the
                attention weights.
            len_max: the L constant of the sinusoidal positional encoding.
            attention_layer: one of "mha", "dumbrec", "kvrec", "caterpillar".
        """
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            # -1 marks a non-caterpillar model (forward tests length > 0)
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        # Builds a fresh attention module on each call, so that blocks do
        # not share parameters
        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        # Pre-norm residual blocks: LayerNorm -> attention, then
        # LayerNorm -> MLP, each wrapped with a residual connection
        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        # Embeddings ~ N(0, 0.02), LayerNorms start as the identity map
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        # NOTE(review): reset_inner_loss is not visible in this excerpt;
        # presumably it resets the sub-layers' regularization accumulators
        self.reset_inner_loss()
881
882     def forward(self, bs):
883         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
884
885         # To make the code simpler in the Caterpillar layer, we pad
886         # here. It's unclear if/how much it hurts computationaly by
887         # increasing the sequence length for the other layers
888
889         if self.caterpillar_length > 0:
890             original_nb = bs.nb
891             if bs.nb % self.caterpillar_length > 0:
892                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
893
894             bs = BracketedSequence(
895                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
896                 bs.first + self.caterpillar_length,
897                 bs.nb,
898                 bs.init_cache,
899             )
900
901         bs = self.embedding(bs)
902         bs = self.trunk(bs)
903         bs = self.readout(bs)
904
905         if self.caterpillar_length > 0:
906             bs = BracketedSequence(
907                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
908                 bs.first - self.caterpillar_length,
909                 original_nb,
910                 bs.init_cache,
911             )
912
913         return bs
914
    # ar_mask is a tensor of 0s and 1s with the same shape as input;
    # tokens at positions marked with 1 are generated, all other
    # positions are kept unchanged.
918
919     def masked_inplace_autoregression(
920         self,
921         input_src,
922         ar_mask_src,
923         forbidden_tokens=None,
924         deterministic_synthesis=False,
925     ):
926         input = input_src.to(self.readout.f.weight.device)
927         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
928         to_generate = (ar_mask.sum(0) > 0).nonzero()
929         if to_generate.min() > 0:
930             self(
931                 BracketedSequence(input, 0, to_generate.min(), True)
932             )  # Needed to initialize the model's cache
933         for s in range(to_generate.min(), to_generate.max() + 1):
934             output = self(BracketedSequence(input, s, 1, s == 0)).x
935             logits = output[:, s]
936             if forbidden_tokens is not None:
937                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
938             if deterministic_synthesis:
939                 t_next = logits.argmax(1)
940             else:
941                 dist = torch.distributions.categorical.Categorical(logits=logits)
942                 t_next = dist.sample()
943             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
944
945         input_src.copy_(input)
946
947     def reset_inner_loss(self):
948         for m in self.modules():
949             if m is not self and hasattr(m, "reset_inner_loss"):
950                 m.reset_inner_loss()
951
952     def get_inner_loss(self):
953         l = torch.tensor([0.0], device=self.readout.f.weight.device)
954         for m in self.modules():
955             if m is not self and hasattr(m, "get_inner_loss"):
956                 l += m.get_inner_loss()
957         return l
958
959     def record_attention(self, v=True):
960         for m in self.modules():
961             if isinstance(m, QKVAttention):
962                 m.record_attention = v
963
964     def retrieve_attention(self):
965         a = []
966         for m in self.modules():
967             if isinstance(m, QKVAttention):
968                 a.append(m.a)
969         return a
970
971
972 ######################################################################
973
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Consistency check of the Caterpillar layer on a random sequence of
    # 21 + 2 * 7 steps, using the bracket of 21 steps that starts at 7:
    # two identical full passes (y1, y2), and the same bracket processed in
    # two chunks of 14 and 7 steps that reuse the cache (y3a, y3b). The two
    # printed maximum differences are presumably expected to be ~0.

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # Early exit: the MyGPT benchmark below is currently disabled. Use
    # sys.exit() rather than the site-provided exit(), which is not
    # defined when Python runs with -S or in embedded interpreters.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # NOTE(review): this code is unreachable; confirm these constructor
    # arguments still match MyGPT's signature before re-enabling it.
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    # Time 3 rounds of 10 full forward passes each.
    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1044
1045 ######################################################################