#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.

# This implementation is equipped with RNN layers that can replace the
# MHA blocks.

import math, warnings

import torch, einops

from torch import nn
from torch.nn import functional as F

import ffutils

# import memload

######################################################################

# A BracketedSequence is a BxTx... tensor with a first time step and a
# number of time steps to compute.

# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when init_cache is True.


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )

        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)


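# As an illustration (added comment, not part of the original code), a
# hypothetical prompt-then-generate pass through a module `m` that
# follows this protocol, with a prompt of length 5, could be:
#
#   m(BracketedSequence(x, first=0, nb=5, init_cache=True))   # the prompt
#   m(BracketedSequence(x, first=5, nb=1, init_cache=False))  # token 5
#   m(BracketedSequence(x, first=6, nb=1, init_cache=False))  # token 6
#
# Each call processes only the bracket [first, first + nb) and relies on
# the cache filled during the previous calls.
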
######################################################################


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


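# Sanity note on AddPositionalEncoding above (added comment, not part of
# the original code): for an even dimension j = 2i we get k = 0, hence
# sin(t / len_max**(2i/D)), and for an odd dimension j = 2i + 1 we get
# k = 1, and the pi/2 shift turns the sine into cos(t / len_max**(2i/D)),
# which matches the sinusoidal encoding of Vaswani et al. 2017.
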
import pscan


# X is /.../xTxD   A is /.../xT   Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


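# Note added for reference (this is an assumption about the semantics of
# pscan.pscan, which lives in pscan.py and is not shown here): it is used
# here as a parallel associative scan computing the linear recurrence
#
#   Y[:, 0] = A[:, 0] * Y_init + X[:, 0]
#   Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
#
# A naive sequential equivalent, for the shapes used in pscan_shape below
# (A is NxT, X is NxTxD, Y_init is NxD), could look like:
#
#   def naive_pscan(A, X, Y_init):
#       Y, y = X.clone(), Y_init
#       for t in range(X.size(1)):
#           y = A[:, t, None] * y + X[:, t]
#           Y[:, t] = y
#       return Y
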
def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y)

    return torch.cat(result, dim=1).reshape(s)


##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt
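        # Added comment for clarity (not in the original code): assuming
        # pscan_shape implements the linear recurrence noted above, each
        # memory line l is updated as
        #
        #   rec_v[:, l, t] = A[:, l, t] * rec_v[:, l, t - 1] + V[:, l, t]
        #
        # so the more attention the heads put on a line, the more of its
        # previous content is overwritten.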

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)


##############################


# Returns a tensor with an additional index at rank win_dim, that moves
# along the same dimension as dim, over the domain {0...win_size-1}, while
# dim is restricted to a domain reduced by win_size-1 values.


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)


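# As a small illustration (added comment, not part of the original code):
#
#   moving_window(torch.arange(5), dim=0, win_dim=1, win_size=3)
#
# returns
#
#   tensor([[0, 1, 2],
#           [1, 2, 3],
#           [2, 3, 4]])
#
# that is, entry [t, l] of the result is x[t + l], with no copy of the
# underlying storage.
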
##############################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d, amplitude=None):
            if amplitude is None:
                amplitude = 1 / math.sqrt(d[-1])
            return nn.Parameter(amplitude * torch.randn(*d))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.proba_gate_dropout = 0.0

        default_b_G = kwargs.get("default_b_G")
        if default_b_G is None:
            default_b_G = -math.log(caterpillar_height - 1)

        logger(f"default_b_G {default_b_G}")

        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))

        self.w_K = randw(nb_heads, dim_qk, dim_model)
        self.w_V = randw(nb_heads, dim_v, dim_model)
        self.w_Q = randw(nb_heads, dim_qk, dim_model)
        self.w_O = randw(dim_v * nb_heads, dim_model)

        self.init_K_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_qk,
        )
        self.init_V_rec = randw(
            caterpillar_height,
            caterpillar_length,
            dim_v,
        )

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Name the dimensions to make the source a bit clearer; that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        H = self.w_V.size(0)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        DM = self.w_O.size(1)
        R = self.caterpillar_height
        L = self.caterpillar_length

        assert (
            t0 >= L and (t1 - t0) % L == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        # We cache values to deal efficiently with auto-regression

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, R, T, DV)
            self.rec_K = X.new_zeros(N, R, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - L : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - L : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, DM)

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates the storing of
        # the new key and value in the R pairs of the current
        # stack. There are R independent gating values, which means
        # that the current K/V may be stored in multiple pairs of the
        # recurrent state, or not at all.

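        # Added comment for clarity (not in the original code): G has
        # shape N x H x R x (t1 - t0), and G[n, h, r, t] is the gate with
        # which head h writes its current key/value into recurrence row r
        # at time step t0 + t.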
        G = (
            torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        # Clip the gating to avoid values greater than 1 when several
        # heads hit the same row

        G = G / G.sum(1, keepdim=True).clamp(min=1)

        ######################################################################
        # Roll the gating indexes

        # warnings.warn("rotating barrel", RuntimeWarning)

        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
        # G = G.gather(dim=2, index=r_barrel.expand_as(G))

        ######################################################################
        # The "flashbacks"

        if self.training and self.proba_gate_dropout > 0.0:
            # This is a better implementation of "flashbacks".

            # G is NxHxRxT, where r indexes the caterpillar's rows.

            warnings.warn("gate dropout", RuntimeWarning)
            epsilon = 0.5

            dropout_head = (
                (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
                .expand_as(G)
                .float()
            )

            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

            dropout_active = (
                torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
            ).long()

            dropout_head *= dropout_active
            dropout_tail *= dropout_active

            G = (
                G
                + dropout_head * (1 - epsilon - G.detach())
                - dropout_tail * G.detach()
            )

        ######################################################################

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)

        # warnings.warn("harmonic recurrence", RuntimeWarning)
        # har = torch.arange(t0, t1, device = G.device).float() + 1
        # A = har / (har + 1)
        # G = G / har

        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)

        # We start from cached values, which matters in inference

        init_rec_V = self.rec_V[:, :, t0 - L : t0]
        init_rec_K = self.rec_K[:, :, t0 - L : t0]

        #################################################################
        # Associative scan

        # Here there is a trick: Since the stack at position t is
        # computed by updating that at position t-L, the parallel
        # scan operates with a period of L. To do so we split the
        # sequence indexing in two axes, the second of size L, and
        # run the parallel scan using the first as the sequence index.

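        # For instance (added comment, not in the original code), with
        # L=4 and t1-t0=8, the time axis is reshaped below into 2 blocks
        # of 4 columns: column c of block b holds time step t0 + 4*b + c,
        # and the scan of length 2 over the block index chains each
        # column to the same column L steps earlier.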
        A = A.unflatten(2, (-1, L))
        gated_V = gated_V.unflatten(2, (-1, L))
        gated_K = gated_K.unflatten(2, (-1, L))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxHxTxFxL where N is the sample index, H
        # the head, T the time, F the row in the caterpillar, and L
        # the column in the caterpillar

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L
        )

        # We have an attention score for each of the RxL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)


##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_model,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
        logger=print,
        **kwargs,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_model)
        self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)


##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
        logger=print,
        **kwargs,
    ):
        super().__init__()

        assert attention_layer in {
            "mha",
            "dumbrec",
            "kvrec",
            "caterpillar",
        }, f"Unknown attention operator {attention_layer}."

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                    logger=logger,
                    **kwargs,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                    **kwargs,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                    logger=logger,
                    **kwargs,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_model=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                    logger=logger,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers.

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of the same shape as the input,
    # with 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a


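# As an illustration (added comment, not part of the original code), a
# hypothetical completion call with a causal MyGPT instance `model` could
# look as follows: the second half of each sequence is flagged in
# ar_mask, zeroed out, and regenerated in place by the model.
#
#   seq = torch.randint(vocabulary_size, (4, 64))
#   ar_mask = (torch.arange(64)[None, :] >= 32).long().repeat(4, 1)
#   seq.masked_fill_(ar_mask == 1, 0)
#   model.masked_inplace_autoregression(seq, ar_mask)
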
######################################################################

if __name__ == "__main__":
    print("Basic check.")

    m = Caterpillar(
        dim_model=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    import time, sys

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################