Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 # This implementation is equipped with RNN layers to replace the MHA
14
15 import math, warnings
16
17 import torch, einops
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 import ffutils
23
24 # import memload
25
26 ######################################################################
27
# A BracketedSequence is a BxTx... tensor together with the index
# `first` of the first time step to compute and the number `nb` of
# time steps to compute.
30
31 # Modules able to process it expect that they will have to process a
32 # first bracket starting at t=0, followed by a succession of brackets
33 # that move forward in time, do not overlap, and cover the axis T with
34 # no holes.
35 #
36 # Although it is more general, for a classical prompt-conditioned
37 # auto-regressive process it will be a first bracket starting at 0 and
38 # of arbitrary length for the "prompt", followed by brackets of length
39 # 1 for the successive tokens.
40 #
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.
43
44
class BracketedSequence:
    """A BxTx... tensor together with the bracket of time steps to compute.

    ``first`` is the index (along dim 1) of the first time step of the
    bracket, ``nb`` its number of time steps, and ``init_cache`` tells
    cache-aware modules whether to (re)allocate their caches.

    Either all three of ``first``, ``nb`` and ``init_cache`` are given, or
    none of them; in the latter case the bracket spans the whole sequence
    and caches are initialized.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        given = [v is not None for v in (first, nb, init_cache)]
        # The bracket is described either fully or not at all.
        assert all(given) or not any(given)

        # Defaults: a single bracket over the full sequence, fresh caches.
        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return the part of ``x`` covered by the bracket (along dim 1)."""
        end = self.first + self.nb
        return self.x[:, self.first : end]

    def complete(self):
        """Return True if the bracket covers the whole time axis."""
        if self.first != 0:
            return False
        return self.nb == self.x.size(1)
61
62
63 ######################################################################
64
65
class CacheWrapper(nn.Module):
    """Apply a module to the current bracket and cache its output.

    The wrapped module (or an ``nn.Sequential`` of the given modules) is
    applied to ``bs.slice()`` only. Its output is written into a cache
    tensor spanning the whole time axis, which is returned as the ``x`` of
    a new BracketedSequence with the same bracket.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-filled rather than uninitialized: time steps that have not
            # been computed yet then hold deterministic values instead of
            # arbitrary memory (which may contain NaN/Inf).
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            # Subsequent brackets must match the cached sequence and fit in it.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
82
83
84 ##############################
85
86
class WithResidual(nn.Module):
    """Residual connection around a bracket-processing module.

    ``forward`` returns a BracketedSequence whose tensor is the input tensor
    plus the tensor produced by the wrapped module, with the input bracket.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        residual = self.f(bs).x
        return BracketedSequence(bs.x + residual, bs.first, bs.nb, bs.init_cache)
94
95
96 ##############################
97
98
class AddPositionalEncoding(nn.Module):
    """Add a sinusoidal positional encoding to a BracketedSequence.

    Following [Vaswani et al. 2017]:
        PE_{t,2i}   = sin(t / L^{2i/D})
        PE_{t,2i+1} = cos(t / L^{2i/D})
    with L = ``len_max`` and D the size of dim 2 of the input. The cosine
    terms are computed as sin(. + pi/2).
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, bs):
        if bs.init_cache:
            # Time index t as a column, feature index j as a row.
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2  # 0 for even features (sin), 1 for odd ones (cos)
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-filled so that time steps not yet processed hold
            # deterministic values rather than uninitialized memory.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
125
126
127 import pscan
128
129
130 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
131
132
def pscan_dim(A, X, Y_init, dim=-2):
    """Run ``pscan.pscan`` along axis ``dim`` of X and return X's shape.

    Shapes, as implied by the reshapes below:
      X      -- any shape, with the scanned (time) axis at ``dim``; ``dim``
                must not be the last axis (that one holds the features).
      A      -- X's shape without its last axis.
      Y_init -- X's shape without axis ``dim``, or None for zeros.

    All axes before ``dim`` are flattened into a single batch axis; the axes
    between ``dim`` and the last one are kept as-is. The exact recurrence is
    defined by the ``pscan`` module (presumably Y[t] = A[t] * Y[t-1] + X[t]
    -- confirm there).
    """
    s = X.size()
    # Batch size (product of the axes before ``dim``) and scan length.
    a, T = s[:dim].numel(), s[dim]

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    return pscan.pscan(A, X, Y_init).reshape(s)
148
149
def pscan_shape(A, X, Y_init):
    """Run ``pscan.pscan`` along the second-to-last axis of X.

    X is .../xTxD, A is .../xT and Y_init is .../xD (or None, meaning
    zeros). All leading axes are flattened into one batch axis, and the
    result is returned with X's shape.
    """
    shape = X.size()
    T, D = shape[-2], shape[-1]

    flat_A = A.reshape(-1, T)
    flat_X = X.reshape(-1, T, D)
    if Y_init is not None:
        flat_init = Y_init.reshape(-1, D)
    else:
        flat_init = flat_X.new_zeros(flat_X.size(0), D)

    return pscan.pscan(flat_A, flat_X, flat_init).reshape(shape)
163
164
def nsum_shape(X, Y_init):
    """Sequential norm-clipped running sum along the second-to-last axis.

    X is .../xTxD and Y_init is .../xD (or None for a zero start). At each
    step the current row of X is added to the running state, which is then
    divided by max(1, ||state||_2), so its norm never exceeds 1. Returns the
    sequence of states, with X's shape.
    """
    shape = X.size()
    steps = X.reshape(-1, shape[-2], shape[-1])  # (batch, T, D)

    state = 0 if Y_init is None else Y_init.reshape(-1, shape[-1])
    states = []
    for x_t in steps.unbind(dim=1):
        state = state + x_t
        # Only shrinks: states with norm <= 1 are left untouched.
        state = state / state.norm(dim=-1, keepdim=True).clamp(min=1)
        states.append(state)

    return torch.stack(states, dim=1).reshape(shape)
178
179
180 ##############################
181
182
class DumbRec(nn.Module):
    """Recurrent replacement for multi-head attention, with fixed keys.

    The recurrent state holds ``nb_lines`` value vectors per sample and time
    step. At each time step, every head distributes its current value over
    the lines through a softmax (over lines) of its write-query against the
    learned keys ``k_star``, rotated with time. The state is then updated
    with a parallel scan. The output is a softmax attention of a read-query
    over the lines, against the (non-rotated) ``k_star``.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility and are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned key per memory line.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)  # write queries
        self.w_qr = randw(nb_heads, dim_qk, dim_model)  # read queries
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared mean write-attention.

        The statistics are accumulated by forward() in training mode. If
        nothing was accumulated since the last reset_inner_loss() (e.g. only
        evaluation passes), a zero scalar is returned instead of dividing by
        zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent state for every time step: N x nb_lines x T x dim_v
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating key barrel: at time t, line l is written using the key
        # k_star[(l + t) % nb_lines].
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        # Each head distributes its write over the lines.
        aw = aw.softmax(dim=2)  # nhlt

        # Accumulate write-attention statistics for get_inner_loss(), in
        # training mode only. (self.train is a method, hence always truthy;
        # self.training is the mode flag.)
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line decay: one minus the total write weight over the heads.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()

        # Continue from the cached state at t0 - 1, or from zeros at t0 = 0.
        V0 = None if t0 == 0 else self.rec_v[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)

        ######################################################################
        # Compute the readout (the keys are not rotated here)

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
309
310
311 ##############################
312
313
class KVRec(nn.Module):
    """Recurrent replacement for multi-head attention, storing key/value pairs.

    The recurrent state holds ``nb_lines`` (key, value) pairs per sample and
    time step.

    Writing: every head distributes its current key and value over the lines
    through a softmax (over lines) of its write-query against the learned
    keys ``k_star``, rotated with time. The state is updated with a parallel
    scan.

    Reading: each head attends over the lines with its read-query against the
    stored keys.

    ``len_max``, ``logger`` and extra keyword arguments are accepted for
    interface compatibility and are not used.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        # One learned addressing key per memory line.
        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)  # write queries
        self.w_qr = randw(nb_heads, dim_qk, dim_model)  # read queries
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Return the sum over lines of the squared mean write-attention.

        The statistics are accumulated by forward() in training mode. If
        nothing was accumulated since the last reset_inner_loss() (e.g. only
        evaluation passes), a zero scalar is returned instead of dividing by
        zero.
        """
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            # Recurrent values and keys for every time step:
            # N x nb_lines x T x dim_v and N x nb_lines x T x dim_qk
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating key barrel: at time t, line l is written using the key
        # k_star[(l + t) % nb_lines].
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # nb_lines x (t1 - t0) x dim_qk

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        # Each head distributes its write over the lines.
        aw = aw.softmax(dim=2)  # nhlt

        # We want all the memory lines to be used similarly: accumulate the
        # write attention (summed across N, H and T) for get_inner_loss(),
        # in training mode only. (self.train is a method, hence always
        # truthy; self.training is the mode flag.)
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line decay: one minus the total write weight over the heads.
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the cached state at t0 - 1, or from zeros at t0 = 0.
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # Compute the readout against the stored keys

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
445
446
447 ##############################
448
449
# Returns a view with an additional index at rank win_dim that moves
# along the same dimension as dim, over the domain {0...win_size-1},
# while dim is restricted to a domain reduced by win_size-1 values.
453
454
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of x exposing sliding windows along ``dim``.

    The result has an extra axis of size ``win_size``, inserted at position
    ``win_dim`` of the output shape, which moves with the same stride as
    ``dim``. Axis ``dim`` is shortened to ``x.size(dim) - win_size + 1``.
    Index i along the shortened axis and w along the window axis address
    element i + w of x along ``dim``.

    No data is copied: the result shares storage with x, and overlapping
    windows alias the same elements.

    ``dim`` may be negative (relative to x's rank); ``win_dim`` is used as
    an insertion position in the shape tuple and should be non-negative.
    """
    # Normalize so that dim=-1 does not break the tuple slicing below.
    dim = dim % x.dim()
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    # Explicit offset so views of sliced tensors start at the right element.
    return x.as_strided(
        size=size, stride=stride, storage_offset=x.storage_offset()
    )
462
463
464 ##############################
465
466
class Caterpillar(nn.Module):
    """Recurrent attention layer over a "caterpillar" of key/value pairs.

    With R = ``caterpillar_height`` and L = ``caterpillar_length``, the
    recurrent state at each time step is R key/value pairs. The state at
    time t is obtained from the state at time t - L by a gated update: for
    each row r, a sigmoid gate computed from the input controls how much of
    each head's current key/value is written. The readout at time t is a
    softmax attention over the R x L pairs of the states at times
    t - L + 1 ... t.

    Optional keyword arguments:
      gate_dropout -- probability of zeroing individual gates in training
                      mode (default 0.0).
      default_bg   -- initial value of the gate biases. By default
                      -log(caterpillar_height - 1), so that the initial gate
                      value is sigmoid(b) = 1 / caterpillar_height; this
                      default requires caterpillar_height > 1.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        ######################################################################
        # sup_args

        x = kwargs.get("gate_dropout")
        if x is None:
            self.proba_gate_dropout = 0.0
        else:
            self.proba_gate_dropout = float(x)

        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")

        x = kwargs.get("default_bg")
        if x is None:
            # -log(R - 1) is only defined for R > 1.
            if caterpillar_height <= 1:
                raise ValueError(
                    "caterpillar_height must be greater than 1 "
                    "unless default_bg is given"
                )
            default_bg = -math.log(caterpillar_height - 1)
        else:
            default_bg = float(x)

        logger(f"default_bg {default_bg}")

        ######################################################################

        # Gating: one weight vector and one bias per (head, row).
        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        # Learnable initial content of the caterpillar: R x L x dim.
        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # This layer has no inner loss: constant zero.
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        # The bracket needs L cached steps before it, and its length must be
        # a multiple of L for the periodic scan below.
        assert t0 >= L and (t1 - t0) % L == 0, (
            f"bs.first ({t0}) should be greater than or equal to "
            f"caterpillar_length ({L}), and bs.nb ({t1 - t0}) should be a "
            f"multiple of caterpillar_length"
        )

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # warnings.warn("softmax gating", RuntimeWarning)

        # G = (
        # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        # ).softmax(dim=2)

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT where R indexes the caterpillar's rows.

            warnings.warn("gate dropout", RuntimeWarning)

            kill = (
                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
            ).float()

            # Inverted dropout: rescale the surviving gates.
            alpha = G / (1 - self.proba_gate_dropout)

            G = alpha * (1 - kill)

        ######################################################################
        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################
        # Roll the gating indexes

        # warnings.warn("rotating barrel", RuntimeWarning)

        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
        # G = G.gather(dim=2, index=r_barrel.expand_as(G))

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)

        # warnings.warn("harmonic recurrence", RuntimeWarning)
        # har = torch.arange(t0, t1, device = G.device).float() + 1
        # A = har / (har + 1)
        # G = G / har

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxRxTxLxD (sample, caterpillar row, time,
        # caterpillar column, feature) holding, at time t, the states at
        # times t-L+1 ... t

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
710
711
712 ##############################
713
714
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a BracketedSequence.

    Keys and values of processed time steps are cached, so a causal model
    can be evaluated bracket by bracket. A non-causal model must receive the
    complete sequence as a single bracket.

    When ``record_attention`` is True, the last attention weights (after the
    softmax, before dropout) are kept in ``self.a``.
    """

    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*shape):
            fan_in = shape[-1]
            return nn.Parameter(torch.randn(*shape) / math.sqrt(fan_in))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x = bs.x
        start, stop = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            nb_samples, length = x.size(0), x.size(1)
            self.cache_k = x.new_zeros(
                nb_samples, self.w_k.size(0), length, self.w_k.size(1)
            )
            self.cache_v = x.new_zeros(
                nb_samples, self.w_v.size(0), length, self.w_v.size(1)
            )
            self.cache_y = x.new_zeros(nb_samples, length, self.w_o.size(1))

        x_new = x[:, start:stop]
        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, start:stop] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Every time step up to the end of the bracket can be attended to.
        keys = self.cache_k[:, :, :stop]
        values = self.cache_v[:, :, :stop]

        scores = torch.einsum("nhtd,nhsd->nhts", q, keys) / math.sqrt(
            self.w_q.size(1)
        )

        if self.causal:
            if bs.init_cache:
                length = x.size(1)
                # True where the key time s is strictly after the query time t.
                self.cache_attzero = torch.ones(
                    length, length, dtype=torch.bool, device=q.device
                ).triu(diagonal=1)[None, None]
            scores = scores.masked_fill(
                self.cache_attzero[:, :, start:stop, :stop], float("-inf")
            )

        weights = scores.softmax(dim=3)

        if self.record_attention:
            self.a = weights

        weights = F.dropout(weights, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", weights, values).flatten(2)

        self.cache_y[:, start:stop] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
797
798
799 ##############################
800
801
802 class MyGPT(nn.Module):
803     def __init__(
804         self,
805         vocabulary_size,
806         dim_model,
807         dim_keys,
808         dim_hidden,
809         nb_heads,
810         nb_blocks,
811         nb_lines=None,
812         caterpillar_height=None,
813         causal=False,
814         dropout=0.0,
815         len_max=1e5,
816         attention_layer="kvrec",
817         logger=print,
818         **kwargs,
819     ):
820         super().__init__()
821
822         assert attention_layer in {
823             "mha",
824             "dumbrec",
825             "kvrec",
826             "caterpillar",
827         }, f"Unknown attention operator {attention_layer}."
828
829         if attention_layer == "caterpillar":
830             assert nb_lines % caterpillar_height == 0
831             self.caterpillar_length = nb_lines // caterpillar_height
832             self.caterpillar_height = caterpillar_height
833         else:
834             self.caterpillar_length = -1
835             self.caterpillar_height = -1
836
837         assert dim_model % nb_heads == 0
838
839         self.embedding = nn.Sequential(
840             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
841             AddPositionalEncoding(len_max),
842         )
843
844         trunk_blocks = []
845
846         def attlayer():
847             if attention_layer == "mha":
848                 return QKVAttention(
849                     dim_model=dim_model,
850                     dim_qk=dim_keys,
851                     dim_v=dim_model // nb_heads,
852                     nb_heads=nb_heads,
853                     causal=causal,
854                     attention_dropout=dropout,
855                     logger=logger,
856                     **kwargs,
857                 )
858             elif attention_layer == "dumbrec":
859                 return DumbRec(
860                     dim_model=dim_model,
861                     dim_qk=dim_keys,
862                     dim_v=dim_model // nb_heads,
863                     nb_heads=nb_heads,
864                     nb_lines=nb_lines,
865                     attention_dropout=dropout,
866                     logger=logger,
867                     **kwargs,
868                 )
869             elif attention_layer == "kvrec":
870                 return KVRec(
871                     dim_model=dim_model,
872                     dim_qk=dim_keys,
873                     dim_v=dim_model // nb_heads,
874                     nb_heads=nb_heads,
875                     nb_lines=nb_lines,
876                     attention_dropout=dropout,
877                     logger=logger,
878                     **kwargs,
879                 )
880             elif attention_layer == "caterpillar":
881                 return Caterpillar(
882                     dim_model=dim_model,
883                     dim_qk=dim_keys,
884                     dim_v=dim_model // nb_heads,
885                     nb_heads=nb_heads,
886                     caterpillar_length=self.caterpillar_length,
887                     caterpillar_height=self.caterpillar_height,
888                     attention_dropout=dropout,
889                     logger=logger,
890                     **kwargs,
891                 )
892             else:
893                 raise ValueError(f"Unknown attention type {attention_layer}.")
894
895         for b in range(nb_blocks):
896             trunk_blocks += [
897                 WithResidual(
898                     CacheWrapper(nn.LayerNorm((dim_model,))),
899                     attlayer(),
900                 ),
901                 WithResidual(
902                     CacheWrapper(
903                         nn.LayerNorm((dim_model,)),
904                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
905                         nn.ReLU(),
906                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
907                         nn.Dropout(dropout),
908                     ),
909                 ),
910             ]
911
912         self.trunk = nn.Sequential(*trunk_blocks)
913
914         self.readout = CacheWrapper(
915             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
916         )
917
918         with torch.no_grad():
919             for m in self.modules():
920                 if isinstance(m, nn.Embedding):
921                     m.weight.normal_(mean=0, std=2e-2)
922                 elif isinstance(m, nn.LayerNorm):
923                     m.bias.zero_()
924                     m.weight.fill_(1.0)
925
926         self.reset_inner_loss()
927
928     def forward(self, bs):
929         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
930
931         # To make the code simpler in the Caterpillar layer, we pad
932         # here. It's unclear if/how much it hurts computationaly by
933         # increasing the sequence length for the other layers
934
935         if self.caterpillar_length > 0:
936             original_nb = bs.nb
937             if bs.nb % self.caterpillar_length > 0:
938                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
939
940             bs = BracketedSequence(
941                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
942                 bs.first + self.caterpillar_length,
943                 bs.nb,
944                 bs.init_cache,
945             )
946
947         bs = self.embedding(bs)
948         bs = self.trunk(bs)
949         bs = self.readout(bs)
950
951         if self.caterpillar_length > 0:
952             bs = BracketedSequence(
953                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
954                 bs.first - self.caterpillar_length,
955                 original_nb,
956                 bs.init_cache,
957             )
958
959         return bs
960
961     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
962     # 1s where tokens should be generated. The others are kept
963     # unchanged.
964
965     def masked_inplace_autoregression(
966         self,
967         input_src,
968         ar_mask_src,
969         forbidden_tokens=None,
970         deterministic_synthesis=False,
971     ):
972         input = input_src.to(self.readout.f.weight.device)
973         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
974         to_generate = (ar_mask.sum(0) > 0).nonzero()
975         if to_generate.min() > 0:
976             self(
977                 BracketedSequence(input, 0, to_generate.min(), True)
978             )  # Needed to initialize the model's cache
979         for s in range(to_generate.min(), to_generate.max() + 1):
980             output = self(BracketedSequence(input, s, 1, s == 0)).x
981             logits = output[:, s]
982             if forbidden_tokens is not None:
983                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
984             if deterministic_synthesis:
985                 t_next = logits.argmax(1)
986             else:
987                 dist = torch.distributions.categorical.Categorical(logits=logits)
988                 t_next = dist.sample()
989             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
990
991         input_src.copy_(input)
992
993     def reset_inner_loss(self):
994         for m in self.modules():
995             if m is not self and hasattr(m, "reset_inner_loss"):
996                 m.reset_inner_loss()
997
998     def get_inner_loss(self):
999         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1000         for m in self.modules():
1001             if m is not self and hasattr(m, "get_inner_loss"):
1002                 l += m.get_inner_loss()
1003         return l
1004
1005     def record_attention(self, v=True):
1006         for m in self.modules():
1007             if isinstance(m, QKVAttention):
1008                 m.record_attention = v
1009
1010     def retrieve_attention(self):
1011         a = []
1012         for m in self.modules():
1013             if isinstance(m, QKVAttention):
1014                 a.append(m.a)
1015         return a
1016
1017
1018 ######################################################################
1019
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    # Caterpillar consistency check. The 21 steps in [7, 28) are processed
    # twice as one bracket (y1, y2), then as a 14-step bracket followed by a
    # 7-step bracket that reuses the cache (y3a, y3b). Both printed maxima
    # are presumably expected to be close to 0.
    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    # The timing benchmark below is opt-in: run with --benchmark to reach
    # it, otherwise the script stops after the basic check. sys.exit is
    # used because the bare exit() is provided by the site module for
    # interactive use only.
    if "--benchmark" not in sys.argv[1:]:
        sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # To profile, wrap the timing loop below in:
    # from torch.profiler import profile, record_function, ProfilerActivity
    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    #              profile_memory=True, record_shapes=True) as prof:
    #     with record_function("model_inference"):
    # and afterwards:
    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    model.eval()
    for _ in range(3):
        start_time = time.perf_counter()
        for _ in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # Step-by-step consistency check of the cache (disabled):
    # y1 = model(BracketedSequence(x)).x
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    #     z = model(BracketedSequence(x, s, 1, s == 0))
    #     y2[:, s : s + 1] = z.slice()
    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1090
1091 ######################################################################