Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression. The attention layer can also be replaced by one
# of the recurrent variants defined below (DumbRec, KVRec, Caterpillar).
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
# A BracketedSequence is a BxTx... tensor together with a bracket of
# nb time steps, starting at time step first, that have to be computed.
28
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
32 # no holes.
33 #
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
38 #
# Modules able to process brackets may implement a cache that is
# reset when the input has init_cache set, which is normally the case
# for the first bracket.
41
42
class BracketedSequence:
    """A BxTx... tensor x together with the bracket of time steps
    [first, first + nb) that has to be computed.

    first, nb and init_cache must be given all together or not at all.
    When omitted, the bracket covers the whole time axis and init_cache
    is True.
    """

    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x

        # Either the three bracket parameters are provided, or none is
        nb_provided = sum(v is not None for v in (first, nb, init_cache))
        assert nb_provided in (0, 3)

        self.first = first if first is not None else 0
        self.nb = nb if nb is not None else x.size(1)
        self.init_cache = init_cache if init_cache is not None else True

    def slice(self):
        """Return x restricted to the bracket along the time axis."""
        start = self.first
        stop = start + self.nb
        return self.x[:, start:stop]

    def complete(self):
        """True when the bracket spans the full time axis of x."""
        return self.first == 0 and self.nb == self.x.size(1)
59
60
61 ######################################################################
62
63
class CacheWrapper(nn.Module):
    """Apply a module to the bracket of a BracketedSequence and store the
    result in a full-length cached output tensor.

    f is only ever applied to bs.slice(), so it must not need context
    outside the current bracket (as for the LayerNorm / Linear / Embedding
    uses in this file). Several modules are chained with nn.Sequential.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            # Zero-initialized, like the other caches in this file, so that
            # time steps not computed yet hold zeros instead of whatever
            # uninitialized memory Tensor.new() would return.
            self.cache_y = y.new_zeros((y.size(0), bs.x.size(1)) + y.size()[2:])
        else:
            # Continuing with a cache allocated by a previous call: batch and
            # time sizes must match and the bracket must fit in it.
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            y = self.f(bs.slice())

        self.cache_y[:, bs.first : bs.first + bs.nb] = y

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
class WithResidual(nn.Module):
    """Residual connection: the output sequence is x + f(x), computed on
    the whole cached tensors, with the bracket of the input unchanged.
    Several modules are chained with nn.Sequential."""

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, bs):
        skip = bs.x
        branch = self.f(bs)
        return BracketedSequence(skip + branch.x, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
class AddPositionalEncoding(nn.Module):
    """Add the sinusoidal positional encoding to the bracket of a
    BracketedSequence whose x is (B, T, D), and cache the result.

    The encoding table is (re)computed for the full length T whenever
    bs.init_cache is set.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    # with L = len_max. The cosine is computed as sin(. + pi/2) on odd channels.

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2  # 0 on even channels (sin), 1 on odd channels (cos)
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            # Zero-initialized like the other caches in this file, so that
            # positions outside the computed brackets are not garbage.
            self.cache_y = bs.x.new_zeros(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
def pscan_dim(A, X, Y_init, dim=-2):
    """Run pscan.pscan with axis `dim` of X as the sequence axis.

    All axes of X before `dim` are collapsed into a single batch axis, the
    axes between `dim` and the last one are kept, and the last axis is the
    feature axis. A is reshaped to X's shape without its last axis, and
    Y_init (zeros when None) to X's shape without axis `dim`. The result
    is reshaped back to X's shape.
    """
    shape = X.size()
    batch = shape[:dim].numel()
    nb_steps = shape[dim]
    middle = shape[dim + 1 : -1]

    A = A.reshape(batch, nb_steps, *middle)
    X = X.reshape(batch, nb_steps, *middle, -1)

    if Y_init is not None:
        Y_init = Y_init.reshape(batch, *middle, -1)
    else:
        Y_init = X.new_zeros(batch, *middle, X.size(-1))

    return pscan.pscan(A, X, Y_init).reshape(shape)
146
147
def pscan_shape(A, X, Y_init):
    """Run pscan.pscan along the second-to-last axis of X.

    X is /.../xTxD, A is /.../xT and Y_init is /.../xD (zeros when None).
    All leading axes are flattened into one batch axis for the scan, and
    the result is reshaped back to X's shape.
    """
    shape = X.size()
    nb_steps, dim_d = shape[-2], shape[-1]

    flat_A = A.reshape(-1, nb_steps)
    flat_X = X.reshape(-1, nb_steps, dim_d)

    if Y_init is not None:
        flat_Y_init = Y_init.reshape(-1, dim_d)
    else:
        flat_Y_init = flat_X.new_zeros(flat_X.size(0), dim_d)

    return pscan.pscan(flat_A, flat_X, flat_Y_init).reshape(shape)
161
162
def nsum_shape(X, Y_init):
    """Sequential running sum of X (/.../xTxD) along T, starting from
    Y_init (/.../xD, or 0 when None).

    After each step the running sum is divided by its L2 norm over D
    clamped below at 1, i.e. it is rescaled only when that norm exceeds 1.
    Returns a tensor shaped like X.
    """
    shape = X.size()
    X = X.reshape(-1, shape[-2], shape[-1])  # ntd

    acc = 0 if Y_init is None else Y_init.reshape(-1, shape[-1])
    steps = []

    for x_t in X.unbind(1):
        acc = acc + x_t
        acc = acc / acc.norm(dim=-1, keepdim=True).clamp(min=1)
        steps.append(acc)

    return torch.stack(steps, dim=1).reshape(shape)
176
177
178 ##############################
179
180
class DumbRec(nn.Module):
    """Recurrent attention layer with nb_lines memory lines of values.

    Write: per-head write queries attend over learned keys k_star (rotated
    by one line per time step) to spread the current per-head values over
    the lines. The lines are then updated with a parallel scan.
    Readout: per-head read queries attend over k_star and return a
    weighted sum of the line values. No keys are stored in the state.

    len_max is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        # self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        """Reset the write-attention statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the average write attention per memory line,
        accumulated over training forward passes since the last
        reset_inner_loss(). Returns 0 if nothing was accumulated (e.g. only
        eval passes), instead of dividing by zero."""
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating barrel: the write key of line l at time t is
        # k_star[(l + t) % nb_lines].

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # Accumulate write-attention statistics for get_inner_loss, in
        # training mode only (self.training, not the always-truthy method
        # self.train).
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line scan coefficient: 1 - total write attention over heads
        # (presumably pscan computes Y_t = A_t * Y_{t-1} + X_t; see pscan)
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the state at t0 - 1 when this is not the first bracket
        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        # NOTE(review): the readout attends over the non-rotated k_star,
        # while the writes use the rotated keys; confirm this is intended.
        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
class KVRec(nn.Module):
    """Recurrent attention layer with nb_lines memory lines, each holding
    a key and a value.

    Write: per-head write queries attend over learned keys k_star (rotated
    by one line per time step) to spread the current per-head keys and
    values over the lines. The lines are then updated with a parallel scan.
    Readout: per-head read queries attend over the stored keys and return
    a weighted sum of the stored values.

    len_max is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        """Reset the write-attention statistics used by get_inner_loss."""
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        """Squared L2 norm of the average write attention per memory line,
        accumulated over training forward passes since the last
        reset_inner_loss(). Returns 0 if nothing was accumulated (e.g. only
        eval passes), instead of dividing by zero."""
        warnings.warn("l2 regularization", RuntimeWarning)
        if self.acc_nb == 0:
            return torch.zeros((), device=self.w_qw.device)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        # Rotating barrel: the write key of line l at time t is
        # k_star[(l + t) % nb_lines].

        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]  # ltd

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        # We want all the memory lines to be used similarly: accumulate the
        # write attention per line (summed across NxHx_xT) for
        # get_inner_loss, in training mode only (self.training, not the
        # always-truthy method self.train).
        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        # Per-line scan coefficient: 1 - total write attention over heads
        # (presumably pscan computes Y_t = A_t * Y_{t-1} + X_t; see pscan)
        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        # Continue from the state at t0 - 1 when this is not the first bracket
        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
def moving_window(x, dim, win_dim, win_size):
    """Return a strided view of all windows of size win_size along axis
    `dim` of x, without copying.

    Axis `dim` becomes the window index (size[dim] - win_size + 1 entries)
    and a new axis of size win_size, indexing inside the window, is
    inserted at position win_dim (win_dim is expected to be > dim).
    """
    old_size, old_stride = tuple(x.size()), x.stride()

    nb_windows = old_size[dim] - win_size + 1
    new_size = (*old_size[:dim], nb_windows, *old_size[dim + 1 :])
    new_size = (*new_size[:win_dim], win_size, *new_size[win_dim:])

    # Moving inside a window steps along `dim`, hence the same stride
    new_stride = (*old_stride[:win_dim], old_stride[dim], *old_stride[win_dim:])

    return x.as_strided(size=new_size, stride=new_stride)
451
452
453 ##############################
454
455
class Caterpillar(nn.Module):
    """Recurrent attention layer whose state is, per sample, a CH x CL
    "caterpillar" of key/value pairs (CH = caterpillar_height rows,
    CL = caterpillar_length columns).

    At each time step, per-head sigmoid gates G set how much of the current
    key and value is written in each of the CH rows. The state at time t is
    obtained from the state at time t - CL with a parallel scan. The readout
    attends, per head, over the CH x CL pairs of the last CL time steps.

    forward() requires bs.first >= CL and bs.nb a multiple of CL.
    len_max is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        # b_G is initialized to -log(CH - 1), which is undefined for CH < 2
        if caterpillar_height < 2:
            raise ValueError(
                f"caterpillar_height must be at least 2, got {caterpillar_height}"
            )

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.w_G = randw(nb_heads, caterpillar_height, dim_in)
        # Bias such that sigmoid(b_G) = 1 / CH
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_in)
        self.w_V = randw(nb_heads, dim_v, dim_in)
        self.w_Q = randw(nb_heads, dim_qk, dim_in)
        self.w_O = randw(dim_v * nb_heads, dim_in)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # No auxiliary loss for this layer. A float zero so that it adds
        # cleanly to a float accumulator.
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0.0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimensions to make the source a bit clearer, that's needed

        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        Dout = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert t0 >= CL and (t1 - t0) % CL == 0, (
            f"bs.first={t0} should be greater than or equal to "
            f"caterpillar_length={CL}, and bs.nb={t1 - t0} should be a "
            "multiple of caterpillar_length"
        )

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_K = X.new_zeros(N, CH, T, DK)
            # We start the recurrent sequences with optimizable
            # initial values. No idea if it helps.
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]

            self.cache_Y = X.new_zeros(N, T, Dout)

        ######################################################################
        # Compute the recurrent state

        # This is the Gating sequence that modulates whether the key and
        # value should be stored in one of the CH pairs of the
        # current stack. The CH gating values are independent, which
        # means that the same thing could be stored up to CH times or
        # not at all

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        # We prepare the arguments for the parallel scan

        A = 1 - G.sum(1)
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        # Here there is a trick: The parallel scan operates with a
        # period of CL, so we split the sequence indexing in two axes,
        # the second of size CL, and run the parallel scan using the
        # other alone as the sequence index.

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        # Put back the sequence index

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        # We build tensors NxFxTxLxD where N is the sample index, F the
        # row in the caterpillar, T the time, L the column in the
        # caterpillar, and D the key or value dimension

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        # We have an attention score for each of the CHxCL values

        ar = torch.einsum(
            "nhtd,nftld->nhtfl",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        # softmax can operate only on one dimension, hence the
        # flattening

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        # Compute the output for each head, flatten to concatenate

        Y = torch.einsum(
            "nhtfl,nftld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        # Compute the final output

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
617
618
619 ##############################
620
621
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention on a BracketedSequence.

    Keys and values are cached over the whole time axis, so a causal model
    can be evaluated bracket by bracket; a non-causal one requires complete
    sequences. When record_attention is set, the post-softmax attention of
    the last call is kept in self.a.
    """

    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            w = torch.randn(*d)
            return nn.Parameter(w / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        # Weight layout: (heads, output dim, input dim)
        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs):
        x_q = bs.x
        t0, t1 = bs.first, bs.first + bs.nb

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            nb_samples, nb_steps = x_q.size(0), x_q.size(1)
            self.cache_k = x_q.new_zeros(
                nb_samples, self.w_k.size(0), nb_steps, self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                nb_samples, self.w_v.size(0), nb_steps, self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(nb_samples, nb_steps, self.w_o.size(1))

        x_new = x_q[:, t0:t1]

        q = torch.einsum("ntc,hdc->nhtd", x_new, self.w_q)
        self.cache_k[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_k)
        self.cache_v[:, :, t0:t1] = torch.einsum("ntc,hdc->nhtd", x_new, self.w_v)

        # Queries of the bracket against every cached key up to its end
        scale = math.sqrt(self.w_q.size(1))
        a = torch.einsum("nhtd,nhsd->nhts", q, self.cache_k[:, :, :t1]) / scale

        if self.causal:
            if bs.init_cache:
                # True where the key time s is strictly after the query time t
                positions = torch.arange(x_q.size(1), device=q.device)
                self.cache_attzero = (
                    positions[None, None, :, None] < positions[None, None, None, :]
                )
            future = self.cache_attzero[:, :, t0:t1, :t1]
            a = a.masked_fill(future, float("-inf"))

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum("nhts,nhsd->nthd", a, self.cache_v[:, :, :t1]).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, bs.nb, bs.init_cache)
702
703
704 ##############################
705
706
class MyGPT(nn.Module):
    """Autoregressive model on BracketedSequence inputs whose x holds token
    indices (NxT).

    Token embedding + sinusoidal positional encoding, then nb_blocks
    residual blocks (LayerNorm + attention layer, LayerNorm + two-layer ReLU
    MLP), and a linear readout to vocabulary logits.

    attention_layer selects the attention block: "mha" (QKVAttention),
    "dumbrec" (DumbRec), "kvrec" (KVRec) or "caterpillar" (Caterpillar).
    The recurrent layers use nb_lines memory lines with value dimension
    dim_rec_v. For "caterpillar", nb_lines must be a multiple of
    caterpillar_height, and caterpillar_length = nb_lines // caterpillar_height.
    """

    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        dim_rec_v=-1,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            # Non-positive length disables the caterpillar padding in forward()
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_rec_v,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        # Shift the tokens one step to the right (0 prepended, last one
        # dropped), so that with causal layers the output at t only depends
        # on the tokens before t.
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationaly by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            original_nb = bs.nb
            # Round the bracket length up to a multiple of caterpillar_length
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            # Remove the padding along T and restore the original bracket
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                original_nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        """Generate in place, one time step at a time, the tokens of
        input_src where ar_mask_src is 1, by sampling from the model logits
        (or taking the argmax if deterministic_synthesis). forbidden_tokens,
        if given, is a mask of tokens whose logits are set to -inf.

        Does nothing when ar_mask_src contains no 1.
        """
        device = self.readout.f.weight.device
        tokens = input_src.to(device)
        ar_mask = ar_mask_src.to(device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()

        # Nothing to generate: .min() / .max() would fail on an empty tensor
        if to_generate.numel() == 0:
            return

        first, last = int(to_generate.min()), int(to_generate.max())

        if first > 0:
            # Needed to initialize the model's cache
            self(BracketedSequence(tokens, 0, first, True))

        for s in range(first, last + 1):
            output = self(BracketedSequence(tokens, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            tokens[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * tokens[:, s]

        input_src.copy_(tokens)

    def reset_inner_loss(self):
        """Call reset_inner_loss() on every submodule that has one."""
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        """Sum of get_inner_loss() over every submodule that has one."""
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()
        return l

    def record_attention(self, v=True):
        """Enable (or disable) attention recording in QKVAttention layers."""
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        """List of the attention maps recorded by the QKVAttention layers."""
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a
907
908
909 ######################################################################
910
if __name__ == "__main__":
    import sys
    import time

    print("Basic check.")

    m = Caterpillar(
        dim_in=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    m.reset_inner_loss()
    x = torch.randn(1, 21 + 2 * 7, 4)
    # Same 21 steps processed twice in one bracket, then as two consecutive
    # brackets using the cache: both printed differences are expected to be
    # close to 0.
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
    # sys.exit rather than the site-module exit(), which is meant for the
    # interactive interpreter only.
    sys.exit(0)

    # NOTE(review): everything below is unreachable because of the exit
    # above. It would also fail as is: MyGPT defaults to
    # attention_layer="kvrec" with dim_rec_v=-1, and KVRec passes dim_v to
    # torch.randn, which rejects a negative size.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    x = x.to(device)
    model.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    # with record_function("model_inference"):

    model.eval()
    for i in range(3):
        start_time = time.perf_counter()
        for k in range(10):
            model(BracketedSequence(x))
        duration = time.perf_counter() - start_time
        print(duration)
        sys.stdout.flush()

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    # z = model(BracketedSequence(x, s, 1))
    # y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
981
982 ######################################################################