Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid an O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers to replace the MHA
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
28 # A BracketedSequence is a BxTx... tensor together with the bracket of
29 # time steps to compute: nb steps starting at index first.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
41 # Modules able to process brackets may implement a cache that is
42 # reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor plus the time bracket [first, first + nb) to process.

    Either all three of ``first``, ``nb`` and ``init_cache`` are given, or
    none of them is. In the latter case the bracket covers the whole time
    axis (``first = 0``, ``nb = x.size(1)``) and ``init_cache`` is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        given = [arg is not None for arg in (first, nb, init_cache)]
        # All-or-nothing: a partially specified bracket is a caller error.
        assert all(given) or not any(given)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the part of ``x`` inside the bracket, along axis 1."""
        end = self.first + self.nb
        return self.x[:, self.first : end]

    def complete(self):
        """Return whether the bracket spans the whole time axis of ``x``."""
        return self.first == 0 and self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a per-time-step module to the current bracket, caching outputs.

    ``f`` (or ``nn.Sequential(*f)`` when several modules are given) is
    applied to ``bs.slice()`` only. The result is written into the
    persistent buffer ``self.cache_y``, which spans the full time axis, and
    that buffer is returned as a BracketedSequence with the same bracket.
    The buffer is (re)allocated when ``bs.init_cache`` is True.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-initialized rather than allocated with the legacy
            # ``Tensor.new``, which leaves the memory uninitialized: time
            # steps outside the processed brackets then hold deterministic
            # values, as in the other cached modules of this file.
            self.cache_y = y.new_zeros(
                (y.size(0), bs.x.size(1)) + tuple(y.size()[2:])
            )
        else:
            # Later brackets must reuse a buffer of matching batch/time size.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            y = self.f(bs.slice())

        self.cache_y[:, bs.first : bs.first + bs.nb] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-aware module.

    ``f`` (or ``nn.Sequential(*f)`` for several modules) maps a
    BracketedSequence to a BracketedSequence. The output tensor is the
    full-length sum ``bs.x + f(bs).x``, carrying the input's bracket and
    cache flag.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        residual = bs.x
        branch = self.f(bs).x
        return BracketedSequence(residual + branch, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a BracketedSequence.

    When ``bs.init_cache`` is True, the encoding table ``self.pe`` (shape
    ``bs.x.size(1) x bs.x.size(2)``) and the output buffer ``self.cache_y``
    are rebuilt. Subsequent brackets reuse both and only fill
    ``cache_y[:, first:first + nb]``.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al. 2017], with L = len_max and D the model dimension:
    #   PE_{t,2i}   = sin(t / L^{2i/D})
    #   PE_{t,2i+1} = cos(t / L^{2i/D})
    # The cosine terms are computed as sin(. + pi/2).

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2  # 0 for sine channels, 1 for cosine channels
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-initialized (the legacy ``Tensor.new`` left it
            # uninitialized) so positions not yet processed are deterministic.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Apply ``pscan.pscan`` along axis ``dim`` of ``X``.

    ``X`` is viewed as (a, T, *mid, D), where:

    * ``a`` is the product of the axes before ``dim``;
    * ``T = X.size(dim)``;
    * ``mid`` are the axes strictly between ``dim`` and the last one;
    * ``D = X.size(-1)``.

    ``A`` must hold as many elements as ``X`` without its last axis.
    ``Y_init`` must hold as many as ``X`` without axis ``dim``; if it is
    None, zeros are used. The result is reshaped to the shape of ``X``.

    Presumably ``pscan.pscan`` computes Y_t = A_t * Y_{t-1} + X_t along T,
    starting from Y_init. TODO confirm against the pscan module.
    """
    s = X.size()
    # Normalize a negative axis so that ``s[dim + 1 : -1]`` always selects
    # the axes strictly between ``dim`` and the last one.
    dim = dim % X.dim()
    a, T = s[:dim].numel(), s[dim]
    mid = s[dim + 1 : -1]

    A = A.reshape(a, T, *mid)
    X = X.reshape(a, T, *mid, -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *mid, X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *mid, -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
148
149
def pscan_shape(A, X, Y_init):
    """Apply ``pscan.pscan`` along axis -2 of ``X``, batching the leading axes.

    ``X`` is flattened to (batch, T, D) and ``A`` to (batch, T). ``Y_init``
    is reshaped to (batch, D), or zeros are used when it is None. The
    result has the shape of ``X``.

    Presumably ``pscan.pscan`` computes Y_t = A_t * Y_{t-1} + X_t along T.
    TODO confirm against the pscan module.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)

    if Y_init is not None:
        flat_init = Y_init.reshape(-1, D)
    else:
        flat_init = flat_X.new_zeros(flat_X.size(0), D)

    return pscan.pscan(flat_A, flat_X, flat_init).reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Running sum along axis -2 with the state's norm capped at 1.

    All leading axes of ``X`` are treated as a batch. At every step the
    current slice is added to the state, then the state is divided by
    ``max(1, ||state||)``, the L2 norm being taken over the last axis.
    ``Y_init`` (reshaped to (batch, D)) is the initial state; when it is
    None the sum starts from zero. The result has the shape of ``X``.
    """
    shape = X.size()
    flat = X.reshape(-1, shape[-2], shape[-1])  # (batch, time, dim)

    state = 0 if Y_init is None else Y_init.reshape(-1, shape[-1])
    steps = []

    for x_t in flat.unbind(dim=1):
        state = state + x_t
        state = state / state.norm(dim=-1, keepdim=True).clamp(min=1)
        steps.append(state)

    return torch.stack(steps, dim=1).reshape(shape)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent stand-in for multi-head attention with ``nb_lines`` memory lines.

    Write step: for every head, a query computed from the input attends
    (softmax over the lines) to the learned keys ``k_star``; the assignment
    of keys to lines is rotated with the time index. The attention weights
    scale how much of each head's value is written into each line. The
    per-line factor ``A = 1 - sum_h weights`` is passed with the written
    values to ``pscan_shape``. Presumably this computes the recurrence
    rec_t = A_t * rec_{t-1} + V_t (TODO confirm against the pscan module).

    Read step: a second query attends to the (unrotated) ``k_star`` and
    reads a weighted mix of the line states ``rec_v``.

    ``len_max`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # Parameters are created in the same order as before so that a given
        # RNG seed yields the same initialization.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squared mean write-attention per line.

        Returns a zero scalar when nothing has been accumulated since the
        last ``reset_inner_loss`` (e.g. only eval-mode forwards), instead of
        dividing 0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.k_star.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # One recurrent value vector per (sample, line, time step).
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        x = x_q[:, t0:t1]

        ######################################################################
        # Prepare the keys: line l at time t uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # ``self.training`` is the mode flag; ``self.train`` is a bound method
        # and therefore always truthy, which made this accumulate in eval too.
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue the recurrence from the last cached state, if any.
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
307
308
309 ##############################
310
311
class KVRec(nn.Module):
    """Recurrent stand-in for multi-head attention with key AND value lines.

    Write step: for every head, a query computed from the input attends
    (softmax over the lines) to the learned keys ``k_star``; the assignment
    of keys to lines is rotated with the time index. The attention weights
    scale how much of each head's key and value is written into each line.
    The per-line factor ``A = 1 - sum_h weights`` is passed to
    ``pscan_shape`` for both the key and value recurrences. Presumably this
    computes rec_t = A_t * rec_{t-1} + X_t (TODO confirm against the pscan
    module).

    Read step: a second query attends to the recurrent keys ``rec_k`` and
    reads a weighted mix of the recurrent values ``rec_v``.

    ``len_max`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # Parameters are created in the same order as before so that a given
        # RNG seed yields the same initialization.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        """Reset the accumulated write-attention statistics."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum of squared mean write-attention per line.

        The statistics are meant to encourage all memory lines to be used
        similarly. Returns a zero scalar when nothing has been accumulated
        since the last ``reset_inner_loss``, instead of dividing 0 by 0.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return self.k_star.new_zeros(())
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # One recurrent value and key vector per (sample, line, time step).
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        x = x_q[:, t0:t1]

        ######################################################################
        # Prepare the keys: line l at time t uses k_star[(l + t) % nb_lines]

        warnings.warn("rotating key barrel", RuntimeWarning)
        nb_lines = self.k_star.size(0)
        t_barrel = torch.arange(t0, t1, device=self.k_star.device)
        l_barrel = (
            torch.arange(nb_lines, device=self.k_star.device)[:, None]
            + t_barrel[None, :]
        ) % nb_lines
        k_star = self.k_star[l_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x, self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x, self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x, self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt, normalized over the lines

        # ``self.training`` is the mode flag; ``self.train`` is a bound method
        # and therefore always truthy, which made this accumulate in eval too.
        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue the recurrences from the last cached states, if any.
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x, self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
441
442
443 ##############################
444
445
446 # Returns a tensor with an additional index at rank win_dim, that moves
447 # along the same dimension as dim, on a domain {0...win_size-1}, and
448 # dim is restricted to a domain reduced by win_size-1 values.
449
450
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of ``x`` exposing sliding windows along ``dim``.

    The result has an extra axis of size ``win_size``, inserted at position
    ``win_dim`` of the output shape (``list.insert`` semantics). That axis
    advances with the same stride as ``dim``, while axis ``dim`` is
    shortened by ``win_size - 1``. Hence ``w[..., i, ..., j, ...]`` reads
    element ``i + j`` of ``x`` along ``dim``.

    No data is copied: the result aliases the storage of ``x``, starting at
    the same storage offset. ``dim`` may be negative (including -1).
    """
    sizes = list(x.size())
    strides = list(x.stride())
    # List indexing handles negative ``dim`` directly; tuple slicing with
    # ``size[dim + 1 :]`` broke for ``dim = -1`` (it yields the whole shape).
    sizes[dim] -= win_size - 1
    sizes.insert(win_dim, win_size)
    strides.insert(win_dim, strides[dim])
    return x.as_strided(size=sizes, stride=strides)
458
459
460 ##############################
461
462
class Caterpillar(nn.Module):
    """Attention replacement with a gated, periodic recurrent key/value stack.

    The state is R x L pairs of (key, value) per sample, where
    R = ``caterpillar_height`` rows and L = ``caterpillar_length`` columns.
    The state at time t is obtained from the state at time t - L:

    * every row decays by ``A = 1 - sum_h G``;
    * every row receives the gated keys/values of the heads.

    This runs as a parallel scan with period L (see ``pscan_dim``). The
    readout attends, per head, over the R x L states of the last L time
    steps.

    ``bs.first`` must be at least L, and ``bs.nb`` a multiple of L: the L
    steps before ``bs.first`` hold the initial state (``init_K_rec`` /
    ``init_V_rec`` when the cache is initialized). ``len_max`` is accepted
    for interface compatibility and is not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        # The gate bias below is -log(R - 1), which is only defined for
        # R >= 2; fail with an explicit message rather than the opaque
        # "math domain error" raised by math.log.
        if caterpillar_height < 2:
            raise ValueError(
                f"caterpillar_height must be at least 2, got {caterpillar_height}"
            )

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        # Probability, per sample, of applying the "flashback" gate dropout
        # during training; hard-coded to 0, i.e. disabled.
        self.proba_gate_dropout = 0.0

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        # sigmoid(-log(R - 1)) = 1 / R: the bias alone opens each gate to 1/R.
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Optimizable initial values of the R x L recurrent stack.
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        """Reset the (currently unused) inner-loss accumulators."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return a zero inner loss; no regularization is applied here."""
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions, named to make the source a bit clearer
        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert t0 >= L and (t1 - t0) % L == 0, (
            f"bs.first ({t0}) should be at least caterpillar_length ({L}), "
            f"and bs.nb ({t1 - t0}) should be a multiple of caterpillar_length"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # G (N x H x R x T) is the gating sequence that modulates the
        # storing of each head's new key and value in the R rows of the
        # current stack. The R gating values are independent, so the
        # current K/V may be stored in several rows, or in none.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        ######################################################################
        # Roll the gating indexes: at time t, row r uses the gate computed
        # for row (r + t // L) % R.

        warnings.warn("rotating barrel", RuntimeWarning)

        r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        r_barrel = (r_barrel + (t_barrel + t0) // L) % R
        G = G.gather(dim=2, index=r_barrel.expand_as(G))

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # G is N x H x R x T, R indexing the caterpillar's rows.
            # For the selected samples, one random time step per (sample,
            # head) gets its gate value forced to 1 - epsilon and all later
            # steps get value 0. The gradient still flows through G since
            # only detached copies are subtracted.

            warnings.warn("gate dropout", RuntimeWarning)
            epsilon = 0.5

            dropout_head = (
                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
                .expand_as(G)
                .float()
            )

            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

            dropout_active = (
                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
            ).long()

            dropout_head *= dropout_active
            dropout_tail *= dropout_active

            G = (
                G
                + dropout_head * (1 - epsilon - G.detach())
                - dropout_tail * G.detach()
            )

        ######################################################################

        # We prepare the arguments for the parallel scan

        # Clip the gating so that the total over heads does not exceed 1
        # when several heads hit the same row
        G = G / G.sum(1, keepdim=True).clamp(min=1)

        A = 1 - G.sum(1)  # N x R x T

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: since the stack at position t is computed
        # by updating that at position t-L, the parallel scan operates with
        # a period of L. To do so we split the time axis in two, the second
        # of size L, and run the parallel scan using the first as the
        # sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # Compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # Windowed states are N x F x T x L x D, where F indexes the row and
        # L the column in the caterpillar. Output step t0 + t sees the
        # states of times t0 + t - L + 1 ... t0 + t.

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score (N x H x T x F x L) for each of the
        # R x L values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the flattening
        # of the F and L axes

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
698
699
700 ##############################
701
702
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a key/value cache.

    Keys, values and outputs are stored in buffers spanning the full time
    axis, so successive brackets only compute the new time steps. Queries
    of the current bracket attend to all keys up to the bracket's end. With
    ``causal=True``, a query at time t cannot attend to keys at times > t.
    Non-causal models only accept complete brackets.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def scaled_randn(*shape):
            # Gaussian init scaled by 1/sqrt of the last (input) dimension.
            return nn.Parameter(torch.randn(*shape) / math.sqrt(shape[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_k = scaled_randn(nb_heads, dim_qk, dim_model)
        self.w_v = scaled_randn(nb_heads, dim_v, dim_model)
        self.w_o = scaled_randn(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        start, end = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            batch, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(
                batch, self.w_k.size(0), length, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                batch, self.w_v.size(0), length, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(batch, length, self.w_o.size(1))

        x_new = x[:, start:end]

        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, start:end] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, start:end] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        scores = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, :end]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                # True where the key index s is after the query index t.
                positions = torch.arange(x.size(1), device=q.device)
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            future = self.cache_attzero[:, :, start:end, :end]
            scores = scores.masked_fill(future, float("-inf"))

        attention = scores.softmax(dim=3)

        if self.record_attention:
            self.a = attention

        attention = F.dropout(attention, self.attention_dropout, self.training)

        heads = torch.einsum(
            "nhts,nhsd->nthd", attention, self.cache_v[:, :, :end]
        ).flatten(2)

        self.cache_y[:, start:end] = heads @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
783
784
785 ##############################
786
787
788 class MyGPT(nn.Module):
789     def __init__(
790         self,
791         vocabulary_size,
792         dim_model,
793         dim_keys,
794         dim_hidden,
795         nb_heads,
796         nb_blocks,
797         nb_lines=None,
798         caterpillar_height=None,
799         causal=False,
800         dropout=0.0,
801         len_max=1e5,
802         attention_layer="kvrec",
803     ):
804         super().__init__()
805
806         assert attention_layer in {
807             "mha",
808             "dumbrec",
809             "kvrec",
810             "caterpillar",
811         }, f"Unknown attention operator {attention_layer}."
812
813         if attention_layer == "caterpillar":
814             assert nb_lines % caterpillar_height == 0
815             self.caterpillar_length = nb_lines // caterpillar_height
816             self.caterpillar_height = caterpillar_height
817         else:
818             self.caterpillar_length = -1
819             self.caterpillar_height = -1
820
821         assert dim_model % nb_heads == 0
822
823         self.embedding = nn.Sequential(
824             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
825             AddPositionalEncoding(len_max),
826         )
827
828         trunk_blocks = []
829
830         def attlayer():
831             if attention_layer == "mha":
832                 return QKVAttention(
833                     dim_model=dim_model,
834                     dim_qk=dim_keys,
835                     dim_v=dim_model // nb_heads,
836                     nb_heads=nb_heads,
837                     causal=causal,
838                     attention_dropout=dropout,
839                 )
840             elif attention_layer == "dumbrec":
841                 return DumbRec(
842                     dim_model=dim_model,
843                     dim_qk=dim_keys,
844                     dim_v=dim_model // nb_heads,
845                     nb_heads=nb_heads,
846                     nb_lines=nb_lines,
847                     attention_dropout=dropout,
848                 )
849             elif attention_layer == "kvrec":
850                 return KVRec(
851                     dim_model=dim_model,
852                     dim_qk=dim_keys,
853                     dim_v=dim_model // nb_heads,
854                     nb_heads=nb_heads,
855                     nb_lines=nb_lines,
856                     attention_dropout=dropout,
857                 )
858             elif attention_layer == "caterpillar":
859                 return Caterpillar(
860                     dim_model=dim_model,
861                     dim_qk=dim_keys,
862                     dim_v=dim_model // nb_heads,
863                     nb_heads=nb_heads,
864                     caterpillar_length=self.caterpillar_length,
865                     caterpillar_height=self.caterpillar_height,
866                     attention_dropout=dropout,
867                 )
868             else:
869                 raise ValueError(f"Unknown attention type {attention_layer}.")
870
871         for b in range(nb_blocks):
872             trunk_blocks += [
873                 WithResidual(
874                     CacheWrapper(nn.LayerNorm((dim_model,))),
875                     attlayer(),
876                 ),
877                 WithResidual(
878                     CacheWrapper(
879                         nn.LayerNorm((dim_model,)),
880                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
881                         nn.ReLU(),
882                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
883                         nn.Dropout(dropout),
884                     ),
885                 ),
886             ]
887
888         self.trunk = nn.Sequential(*trunk_blocks)
889
890         self.readout = CacheWrapper(
891             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
892         )
893
894         with torch.no_grad():
895             for m in self.modules():
896                 if isinstance(m, nn.Embedding):
897                     m.weight.normal_(mean=0, std=2e-2)
898                 elif isinstance(m, nn.LayerNorm):
899                     m.bias.zero_()
900                     m.weight.fill_(1.0)
901
902         self.reset_inner_loss()
903
904     def forward(self, bs):
905         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
906
907         # To make the code simpler in the Caterpillar layer, we pad
908         # here. It's unclear if/how much it hurts computationaly by
909         # increasing the sequence length for the other layers
910
911         if self.caterpillar_length > 0:
912             original_nb = bs.nb
913             if bs.nb % self.caterpillar_length > 0:
914                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
915
916             bs = BracketedSequence(
917                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
918                 bs.first + self.caterpillar_length,
919                 bs.nb,
920                 bs.init_cache,
921             )
922
923         bs = self.embedding(bs)
924         bs = self.trunk(bs)
925         bs = self.readout(bs)
926
927         if self.caterpillar_length > 0:
928             bs = BracketedSequence(
929                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
930                 bs.first - self.caterpillar_length,
931                 original_nb,
932                 bs.init_cache,
933             )
934
935         return bs
936
    # ar_mask is a tensor of 0s and 1s with the same shape as input:
    # tokens are generated where it is 1, and the others are kept
    # unchanged.
940
941     def masked_inplace_autoregression(
942         self,
943         input_src,
944         ar_mask_src,
945         forbidden_tokens=None,
946         deterministic_synthesis=False,
947     ):
948         input = input_src.to(self.readout.f.weight.device)
949         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
950         to_generate = (ar_mask.sum(0) > 0).nonzero()
951         if to_generate.min() > 0:
952             self(
953                 BracketedSequence(input, 0, to_generate.min(), True)
954             )  # Needed to initialize the model's cache
955         for s in range(to_generate.min(), to_generate.max() + 1):
956             output = self(BracketedSequence(input, s, 1, s == 0)).x
957             logits = output[:, s]
958             if forbidden_tokens is not None:
959                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
960             if deterministic_synthesis:
961                 t_next = logits.argmax(1)
962             else:
963                 dist = torch.distributions.categorical.Categorical(logits=logits)
964                 t_next = dist.sample()
965             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
966
967         input_src.copy_(input)
968
969     def reset_inner_loss(self):
970         for m in self.modules():
971             if m is not self and hasattr(m, "reset_inner_loss"):
972                 m.reset_inner_loss()
973
974     def get_inner_loss(self):
975         l = torch.tensor([0.0], device=self.readout.f.weight.device)
976         for m in self.modules():
977             if m is not self and hasattr(m, "get_inner_loss"):
978                 l += m.get_inner_loss()
979         return l
980
981     def record_attention(self, v=True):
982         for m in self.modules():
983             if isinstance(m, QKVAttention):
984                 m.record_attention = v
985
986     def retrieve_attention(self):
987         a = []
988         for m in self.modules():
989             if isinstance(m, QKVAttention):
990                 a.append(m.a)
991         return a
992
993
994 ######################################################################
995
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Caterpillar sanity check: prints the max absolute difference between
    # two identical full passes, and between a single bracket of 21 steps
    # and the same steps processed as a bracket of 14 followed (with the
    # cache kept) by a bracket of 7. The steps of interest are at positions
    # 7..27, with 7 extra positions on each side of the time axis.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # NOTE: the benchmark below is currently unreachable; remove this line
    # to run it. sys.exit is used rather than the exit() builtin, which is
    # added by the site module for interactive use and is missing under -S.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1065     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1066
1067 ######################################################################