Update.
[mygptrnn.git] / mygpt.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # This is an implementation from scratch of a "GPT", that is a model
9 # composed of several causal self-attention blocks. It is equipped
10 # with a caching mechanism for keys and values to avoid a O(N^3) cost
11 # for auto-regression.
12
13 import math, warnings
14
15 import torch, einops
16
17 from torch import nn
18 from torch.nn import functional as F
19
20 import ffutils
21
22 # import memload
23
24 ######################################################################
25
26 # A BracketedSequence is a BxTx... tensor with a first and a nb time
27 # steps to compute.
28
29 # Modules able to process it expect that they will have to process a
30 # first bracket starting at t=0, followed by a succession of brackets
31 # that move forward in time, do not overlap, and cover the axis T with
32 # no holes.
33 #
34 # Although it is more general, for a classical prompt-conditioned
35 # auto-regressive process it will be a first bracket starting at 0 and
36 # of arbitrary length for the "prompt", followed by brackets of length
37 # 1 for the successive tokens.
38 #
39 # Modules able to process brackets may implement a cache that is
40 # resetted when the input bracket starts at t=0
41
42
43 class BracketedSequence:
44     def __init__(self, x, first=None, nb=None, init_cache=None):
45         self.x = x
46         assert (first is None and nb is None and init_cache is None) or (
47             first is not None and nb is not None and init_cache is not None
48         )
49
50         self.first = 0 if first is None else first
51         self.nb = x.size(1) if nb is None else nb
52         self.init_cache = True if init_cache is None else init_cache
53
54     def slice(self):
55         return self.x[:, self.first : self.first + self.nb]
56
57     def complete(self):
58         return self.first == 0 and self.nb == self.x.size(1)
59
60
61 ######################################################################
62
63
64 class CacheWrapper(nn.Module):
65     def __init__(self, *f):
66         super().__init__()
67         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
68
69     def forward(self, bs):
70         if bs.init_cache:
71             y = self.f(bs.slice())
72             self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
73             self.cache_y[:, bs.first : bs.first + bs.nb] = y
74         else:
75             assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
76             assert bs.first + bs.nb <= self.cache_y.size(1)
77             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
78
79         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
80
81
82 ##############################
83
84
85 class WithResidual(nn.Module):
86     def __init__(self, *f):
87         super().__init__()
88         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
89
90     def forward(self, bs):
91         return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
92
93
94 ##############################
95
96
97 class AddPositionalEncoding(nn.Module):
98     def __init__(self, len_max):
99         super().__init__()
100         self.len_max = len_max
101
102     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
103
104     def forward(self, bs):
105         if bs.init_cache:
106             t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
107                 :, None
108             ]
109             j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
110                 None, :
111             ]
112             k = j % 2
113             self.pe = torch.sin(
114                 t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
115             )
116             self.cache_y = bs.x.new(bs.x.size())
117
118         self.cache_y[:, bs.first : bs.first + bs.nb] = (
119             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
120         )
121
122         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
123
124
125 import pscan
126
127
128 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
129
130
131 def pscan_dim(A, X, Y_init, dim=-2):
132     s = X.size()
133     a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
134
135     A = A.reshape(a, T, *s[dim + 1 : -1])
136     X = X.reshape(a, T, *s[dim + 1 : -1], -1)
137
138     if Y_init is None:
139         Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
140     else:
141         Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
142
143     Y = pscan.pscan(A, X, Y_init).reshape(s)
144
145     return Y
146
147
148 def pscan_shape(A, X, Y_init):
149     s = X.size()
150     A = A.reshape(-1, s[-2])
151     X = X.reshape(-1, s[-2], s[-1])
152
153     if Y_init is None:
154         Y_init = X.new_zeros(X.size(0), s[-1])
155     else:
156         Y_init = Y_init.reshape(-1, s[-1])
157
158     Y = pscan.pscan(A, X, Y_init).reshape(s)
159
160     return Y
161
162
163 def nsum_shape(X, Y_init):
164     s = X.size()
165     X = X.reshape(-1, s[-2], s[-1])  # ntd
166
167     Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
168     result = []
169
170     for k in range(X.size(1)):
171         Y = Y + X[:, k]
172         Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
173         result.append(Y)
174
175     return torch.cat(result, dim=1).reshape(s)
176
177
178 ##############################
179
180
181 class DumbRec(nn.Module):
182     def __init__(
183         self,
184         dim_model,
185         dim_qk,
186         dim_v,
187         nb_heads,
188         nb_lines,
189         attention_dropout=0.0,
190         len_max=1e5,
191     ):
192         super().__init__()
193
194         def randw(*d):
195             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
196
197         self.nb_lines = nb_lines
198         self.attention_dropout = attention_dropout
199
200         self.k_star = randw(nb_lines, dim_qk)
201
202         self.w_qw = randw(nb_heads, dim_qk, dim_model)
203         self.w_qr = randw(nb_heads, dim_qk, dim_model)
204         # self.w_k = randw(nb_heads, dim_qk, dim_model)
205         self.w_v = randw(nb_heads, dim_v, dim_model)
206         self.w_o = randw(dim_v * nb_heads, dim_model)
207
208     def reset_inner_loss(self):
209         self.acc_attention = 0
210         self.acc_nb = 0
211
212     def get_inner_loss(self):
213         warnings.warn("l2 regularization", RuntimeWarning)
214         return (self.acc_attention / self.acc_nb).pow(2).sum()
215         # return torch.tensor([0], device=self.w_qw.device)
216
217     def forward(self, bs):
218         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
219
220         if bs.init_cache:
221             self.rec_v = x_q.new_zeros(
222                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
223             )
224             # self.rec_k = x_q.new_zeros(
225             # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
226             # )
227             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
228
229         ######################################################################
230         # Prepare the keys
231
232         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
233
234         warnings.warn("rotating key barrel", RuntimeWarning)
235         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
236         t_barrel = torch.arange(t0, t1, device=k_star.device)
237         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
238         l_barrel = (
239             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
240         ) % k_star.size(0)
241         k_star = k_star[l_barrel, t_barrel]
242
243         ######################################################################
244         # Compute the recurrent state
245
246         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
247
248         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
249         # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
250
251         aw = torch.einsum(
252             "nhtd,ltd->nhlt",
253             qw,
254             k_star,
255         ) / math.sqrt(self.w_qw.size(1))
256
257         aw = aw.softmax(dim=2)  # nhlt
258
259         if self.train:
260             self.acc_attention += aw.sum(dim=(0, 1, 3))
261             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
262
263         aw = F.dropout(aw, self.attention_dropout, self.training)
264
265         A = 1 - aw.sum(dim=1)  # nlt
266
267         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
268         # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
269
270         if t0 == 0:
271             V0 = None
272             # K0 = None
273         else:
274             V0 = self.rec_v[:, :, t0 - 1]
275             # K0 = self.rec_k[:, :, t0 - 1]
276
277         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
278         # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
279
280         ######################################################################
281         # compute the readout
282
283         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
284
285         ar = torch.einsum(
286             "nhtd,ld->nhlt",
287             qr,
288             # self.rec_k[:, :, t0:t1],
289             self.k_star,
290         ) / math.sqrt(self.w_qr.size(1))
291
292         ar = ar.softmax(dim=2)  # nhlt
293
294         ar = F.dropout(ar, self.attention_dropout, self.training)
295
296         y = torch.einsum(
297             "nhlt,nltd->nthd",
298             ar,
299             self.rec_v[:, :, t0:t1],
300         ).flatten(2)
301
302         self.cache_y[:, t0:t1] = y @ self.w_o
303
304         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
305
306
307 ##############################
308
309
310 class KVRec(nn.Module):
311     def __init__(
312         self,
313         dim_model,
314         dim_qk,
315         dim_v,
316         nb_heads,
317         nb_lines,
318         attention_dropout=0.0,
319         len_max=1e5,
320     ):
321         super().__init__()
322
323         def randw(*d):
324             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
325
326         self.nb_lines = nb_lines
327         self.attention_dropout = attention_dropout
328
329         self.k_star = randw(nb_lines, dim_qk)
330
331         self.w_qw = randw(nb_heads, dim_qk, dim_model)
332         self.w_qr = randw(nb_heads, dim_qk, dim_model)
333         self.w_k = randw(nb_heads, dim_qk, dim_model)
334         self.w_v = randw(nb_heads, dim_v, dim_model)
335         self.w_o = randw(dim_v * nb_heads, dim_model)
336
337     def reset_inner_loss(self):
338         self.acc_attention = 0
339         self.acc_nb = 0
340
341     def get_inner_loss(self):
342         warnings.warn("l2 regularization", RuntimeWarning)
343         return (self.acc_attention / self.acc_nb).pow(2).sum()
344         # return torch.tensor([0], device=self.w_qw.device)
345         # warnings.warn("side regularization", RuntimeWarning)
346         # return (
347         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
348         # )
349         # return torch.tensor([0], device=self.w_qw.device)
350
351     def forward(self, bs):
352         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
353
354         if bs.init_cache:
355             self.rec_v = x_q.new_zeros(
356                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
357             )
358             self.rec_k = x_q.new_zeros(
359                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
360             )
361             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
362
363         ######################################################################
364         # Prepare the keys
365
366         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
367
368         warnings.warn("rotating key barrel", RuntimeWarning)
369         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
370         t_barrel = torch.arange(t0, t1, device=k_star.device)
371         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
372         l_barrel = (
373             torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
374         ) % k_star.size(0)
375         k_star = k_star[l_barrel, t_barrel]
376
377         ######################################################################
378         # Compute the recurrent state
379
380         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
381
382         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
383         k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
384
385         aw = torch.einsum(
386             "nhtd,ltd->nhlt",
387             qw,
388             k_star,
389         ) / math.sqrt(self.w_qw.size(1))
390
391         aw = aw.softmax(dim=2)  # nhlt
392
393         if self.train:
394             # We want all the memory lines to be used similarly
395             self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum accross NxHx_xT
396             self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
397
398         aw = F.dropout(aw, self.attention_dropout, self.training)
399
400         A = 1 - aw.sum(dim=1)  # nlt
401
402         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
403         K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
404
405         if t0 == 0:
406             V0 = None
407             K0 = None
408         else:
409             V0 = self.rec_v[:, :, t0 - 1]
410             K0 = self.rec_k[:, :, t0 - 1]
411
412         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
413         self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
414
415         ######################################################################
416         # compute the readout
417
418         qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
419
420         ar = torch.einsum(
421             "nhtd,nltd->nhlt",
422             qr,
423             self.rec_k[:, :, t0:t1],
424         ) / math.sqrt(self.w_qr.size(1))
425
426         ar = ar.softmax(dim=2)  # nhlt
427
428         ar = F.dropout(ar, self.attention_dropout, self.training)
429
430         y = torch.einsum(
431             "nhlt,nltd->nthd",
432             ar,
433             self.rec_v[:, :, t0:t1],
434         ).flatten(2)
435
436         self.cache_y[:, t0:t1] = y @ self.w_o
437
438         return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
439
440
441 ##############################
442
443
444 # Returns a tensor with an additional index at rank win_dim, that move
445 # along the same dimension as dim, on a domain {0...win_size-1}, and
446 # dim is restricted on a domain reduced by win_size-1 values.
447
448
449 def moving_window(x, dim, win_dim, win_size):
450     size, stride = x.size(), x.stride()
451     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
452     size = size[:win_dim] + (win_size,) + size[win_dim:]
453     stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
454
455     return x.as_strided(size=size, stride=stride)
456
457
458 ##############################
459
460 # This is one order of magnitude more complicated than I expected, not
461 # elegant, slow, hopefully not buggy
462
463
464 def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
465     # starting flash backs
466     fb_start = (torch.rand(N, CH, t1 - t0, device=device) <= proba).long()
467     fb_start[:, :, -CL:] = 0
468     fb_start[:, :, :CL] = 0
469
470     # Remove series longer than CL
471     fb_body = fb_start.clone()
472     fb_body[:, :, CL + 1 :] -= fb_start[:, :, : -(CL + 1)]
473     fb_body = fb_body.cumsum(dim=2)
474     fb_start = fb_start * (fb_body == 1)
475
476     # Set a origin source time (starting time of the chunck to copy
477     # here) We set it as the current time minus a multiple of CL to be
478     # consistent with the "rolling" caterpillar
479     t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :]
480     src_time = fb_start * (
481         t
482         - CL
483         * (
484             1
485             + (
486                 torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1)
487             ).long()
488         )
489     )
490     src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
491     src_time = src_time.cumsum(dim=2)
492
493     src_head = fb_start * torch.randint(H, fb_start.size(), device=fb_start.device)
494     src_head[:, :, CL:] -= src_head.clone()[:, :, :-CL]
495     src_head = src_head.cumsum(dim=2)
496
497     # combine
498     src_delta = fb_start.clone()
499     src_delta[:, :, CL:] -= fb_start[:, :, :-CL]
500     src_delta = src_delta.cumsum(dim=2)
501     src_delta[:, :, CL:] -= CL * fb_start[:, :, :-CL]
502     src_time += src_delta.cumsum(dim=2) - 1
503
504     return src_time, src_head
505
506
507 def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
508     N, H, CH = V.size(0), V.size(1), rec_V.size(1)
509
510     fbt, fbh = flash_back_time_src(N, H, t0, t1, CL, CH, proba, rec_V.device)
511
512     fbt_V = fbt[:, :, :, None]
513     fbh_V = fbh[:, :, :, None]
514     t = fbt_V.clamp(min=0)
515     n = torch.arange(V.size(0), device=V.device)[:, None, None, None]
516     d = torch.arange(V.size(3), device=V.device)[None, None, None, :]
517     q = V[:, :, t0:t1][n, fbh_V, t, d]
518     rec_V[:, :, t0:t1] = q * (fbt_V >= 0) + rec_V[:, :, t0:t1] * (fbt_V < 0)
519
520     fbt_K = fbt[:, :, :, None]
521     fbh_K = fbh[:, :, :, None]
522     t = fbt_K.clamp(min=0)
523     n = torch.arange(K.size(0), device=K.device)[:, None, None, None]
524     d = torch.arange(K.size(3), device=K.device)[None, None, None, :]
525     q = K[:, :, t0:t1][n, fbh_K, t, d]
526     rec_K[:, :, t0:t1] = q * (fbt_K >= 0) + rec_K[:, :, t0:t1] * (fbt_K < 0)
527
528
529 ######################################################################
530
531
532 class Caterpillar(nn.Module):
533     def __init__(
534         self,
535         dim_model,
536         dim_qk,
537         dim_v,
538         nb_heads,
539         caterpillar_length,
540         caterpillar_height,
541         attention_dropout=0.0,
542         len_max=1e5,
543     ):
544         super().__init__()
545
546         warnings.warn("Caterpillar", RuntimeWarning)
547
548         def randw(*d):
549             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
550
551         self.caterpillar_length = caterpillar_length
552         self.caterpillar_height = caterpillar_height
553         self.attention_dropout = attention_dropout
554
555         warnings.warn("flash back", RuntimeWarning)
556         self.proba_flashback = 0.1
557
558         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
559         self.b_G = nn.Parameter(
560             torch.full(
561                 (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
562             )
563         )
564
565         self.w_K = randw(nb_heads, dim_qk, dim_model)
566         self.w_V = randw(nb_heads, dim_v, dim_model)
567         self.w_Q = randw(nb_heads, dim_qk, dim_model)
568         self.w_O = randw(dim_v * nb_heads, dim_model)
569
570         self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
571         self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
572
573     def reset_inner_loss(self):
574         self.acc_attention = 0
575         self.acc_nb = 0
576
577     def get_inner_loss(self):
578         # warnings.warn("l2 regularization", RuntimeWarning)
579         # return (self.acc_attention / self.acc_nb).pow(2).sum()
580         return torch.tensor([0], device=self.w_Q.device)
581
582     def forward(self, bs):
583         # Dimensions to make the source a bit clearer, that's needed
584
585         X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
586
587         N = bs.x.size(0)
588         T = bs.x.size(1)
589         H = self.w_V.size(0)
590         DV = self.w_V.size(1)
591         DK = self.w_K.size(1)
592         DM = self.w_O.size(1)
593         CH = self.caterpillar_height
594         CL = self.caterpillar_length
595
596         assert (
597             t0 >= CL and (t1 - t0) % CL == 0
598         ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
599
600         # We cache values to deal efficiently with auto-regression
601
602         if bs.init_cache:
603             self.rec_V = X.new_zeros(N, CH, T, DV)
604             self.rec_K = X.new_zeros(N, CH, T, DK)
605             # We start the recurrent sequences with optimizable
606             # initial values. No idea if it helps.
607             self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
608             self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
609
610             self.cache_Y = X.new_zeros(N, T, DM)
611
612         ######################################################################
613         # Compute the recurrent state
614
615         # This is the Gating sequence that modulates the storing of
616         # the new key and value in the CH pairs of the current
617         # stack. The CH gating values are independent, which means
618         # that the current K/V could be stored in multiple pairs of the
619         # recurrent state, or not at all.
620
621         G = (
622             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
623         ).sigmoid()
624
625         # That bas a bad idea
626         # G = F.dropout(G, self.attention_dropout, self.training)
627
628         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
629         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
630
631         # We prepare the arguments for the parallel scan
632
633         A = 1 - G.sum(1)
634         gated_V = torch.einsum("nhet,nhtd->netd", G, V)
635         gated_K = torch.einsum("nhet,nhtd->netd", G, K)
636
637         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
638         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
639
640         # Here there is a trick: Since the stack at time t is computed
641         # by updating that at time t-L, the parallel scan operates
642         # with a period of L. To do so we split the time indexing in
643         # two axes, the second of size CL, and run the parallel scan
644         # using the other as the sequence index.
645
646         A = A.unflatten(2, (-1, CL))
647         gated_V = gated_V.unflatten(2, (-1, CL))
648         gated_K = gated_K.unflatten(2, (-1, CL))
649
650         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
651         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
652
653         # Put back the sequence index
654
655         self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
656         self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
657
658         if self.training and self.proba_flashback:
659             # This piece of code makes the assumption that there is
660             # nothing informative before t0, otherwise we'd have to
661             # implement a cache for V and K too. This should not be
662             # too much of a problem since this is used only during
663             # train, where full sequence are available
664
665             # insert_flash_back(
666             # self.rec_V,
667             # V,
668             # self.rec_K,
669             # K,
670             # t0,
671             # t1,
672             # CL,
673             # proba=self.proba_flashback / CL,
674             # )
675
676             n = torch.arange(N, device=X.device)[:, None, None, None]
677             t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
678             dv = torch.arange(DV, device=X.device)[None, None, None, :]
679             dk = torch.arange(DK, device=X.device)[None, None, None, :]
680
681             u = (
682                 torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
683             ) * CL
684
685             src_time = t - u - t0
686             src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
687
688             mask_V = (
689                 torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
690             ).long()
691             self.rec_V[:, :, t0:t1] = (
692                 mask_V * V[n, src_head, src_time, dv]
693                 + (1 - mask_V) * self.rec_V[:, :, t0:t1]
694             )
695
696             mask_K = (
697                 torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
698             ).long()
699             self.rec_K[:, :, t0:t1] = (
700                 mask_K * K[n, src_head, src_time, dk]
701                 + (1 - mask_K) * self.rec_K[:, :, t0:t1]
702             )
703
704         ######################################################################
705         # compute the readout
706
707         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
708
709         # We build tensors NxHxTxFxL where N is the sample index, H
710         # the head, T the time, F the row in the caterpillar, and L
711         # the column in the caterpillar
712
713         windowed_V = moving_window(
714             self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
715         )
716
717         windowed_K = moving_window(
718             self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
719         )
720
721         # We have an attention score for each of the CHxCL values
722
723         ar = torch.einsum(
724             "nhtd,nftld->nhtfl",
725             Q,
726             windowed_K,
727         ) / math.sqrt(DK)
728
729         # softmax can operate only on one dimension, hence the
730         # flattening
731
732         ar = ar.flatten(3).softmax(dim=3).view(ar.size())
733
734         ar = F.dropout(ar, self.attention_dropout, self.training)
735
736         # Compute the output for each head, flatten to concatenate
737
738         Y = torch.einsum(
739             "nhtfl,nftld->nthd",
740             ar,
741             windowed_V,
742         ).flatten(2)
743
744         # Compute the final output
745
746         self.cache_Y[:, t0:t1] = Y @ self.w_O
747
748         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
749
750
751 ##############################
752
753
754 class QKVAttention(nn.Module):
755     def __init__(
756         self,
757         dim_model,
758         dim_qk,
759         dim_v,
760         nb_heads=1,
761         causal=False,
762         attention_dropout=0.0,
763     ):
764         super().__init__()
765
766         def randw(*d):
767             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
768
769         self.causal = causal
770         self.attention_dropout = attention_dropout
771         self.record_attention = False
772
773         self.w_q = randw(nb_heads, dim_qk, dim_model)
774         self.w_k = randw(nb_heads, dim_qk, dim_model)
775         self.w_v = randw(nb_heads, dim_v, dim_model)
776         self.w_o = randw(dim_v * nb_heads, dim_model)
777
778     def forward(self, bs):
779         x_q = bs.x
780
781         assert (
782             self.causal or bs.complete()
783         ), "Partial evaluation is only possible for causal models"
784
785         if bs.init_cache:
786             self.cache_k = x_q.new_zeros(
787                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
788             )
789             self.cache_v = x_q.new_zeros(
790                 x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
791             )
792             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
793
794         q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
795
796         self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
797             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
798         )
799         self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
800             "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
801         )
802
803         a = torch.einsum(
804             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
805         ) / math.sqrt(self.w_q.size(1))
806
807         if self.causal:
808             if bs.init_cache:
809                 self.cache_attzero = (
810                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
811                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
812                 )
813             a = a.masked_fill(
814                 self.cache_attzero[
815                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
816                 ],
817                 float("-inf"),
818             )
819
820         a = a.softmax(dim=3)
821
822         if self.record_attention:
823             self.a = a
824
825         a = F.dropout(a, self.attention_dropout, self.training)
826
827         y = torch.einsum(
828             "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
829         ).flatten(2)
830
831         self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
832
833         return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
834
835
836 ##############################
837
838
839 class MyGPT(nn.Module):
840     def __init__(
841         self,
842         vocabulary_size,
843         dim_model,
844         dim_keys,
845         dim_hidden,
846         nb_heads,
847         nb_blocks,
848         nb_lines=None,
849         caterpillar_height=None,
850         dim_rec_v=-1,
851         causal=False,
852         dropout=0.0,
853         len_max=1e5,
854         attention_layer="kvrec",
855     ):
856         super().__init__()
857
858         assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
859
860         if attention_layer == "caterpillar":
861             assert nb_lines % caterpillar_height == 0
862             self.caterpillar_length = nb_lines // caterpillar_height
863             self.caterpillar_height = caterpillar_height
864         else:
865             self.caterpillar_length = -1
866             self.caterpillar_height = -1
867
868         assert dim_model % nb_heads == 0
869
870         self.embedding = nn.Sequential(
871             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
872             AddPositionalEncoding(len_max),
873         )
874
875         trunk_blocks = []
876
877         def attlayer():
878             if attention_layer == "mha":
879                 return QKVAttention(
880                     dim_model=dim_model,
881                     dim_qk=dim_keys,
882                     dim_v=dim_model // nb_heads,
883                     nb_heads=nb_heads,
884                     causal=causal,
885                     attention_dropout=dropout,
886                 )
887             elif attention_layer == "dumbrec":
888                 return DumbRec(
889                     dim_model=dim_model,
890                     dim_qk=dim_keys,
891                     dim_v=dim_rec_v,
892                     nb_heads=nb_heads,
893                     nb_lines=nb_lines,
894                     attention_dropout=dropout,
895                 )
896             elif attention_layer == "kvrec":
897                 return KVRec(
898                     dim_model=dim_model,
899                     dim_qk=dim_keys,
900                     dim_v=dim_rec_v,
901                     nb_heads=nb_heads,
902                     nb_lines=nb_lines,
903                     attention_dropout=dropout,
904                 )
905             elif attention_layer == "caterpillar":
906                 return Caterpillar(
907                     dim_model=dim_model,
908                     dim_qk=dim_keys,
909                     dim_v=dim_rec_v,
910                     nb_heads=nb_heads,
911                     caterpillar_length=self.caterpillar_length,
912                     caterpillar_height=self.caterpillar_height,
913                     attention_dropout=dropout,
914                 )
915             else:
916                 raise ValueError(f"Unknown attention type {attention_layer}.")
917
918         for b in range(nb_blocks):
919             trunk_blocks += [
920                 WithResidual(
921                     CacheWrapper(nn.LayerNorm((dim_model,))),
922                     attlayer(),
923                 ),
924                 WithResidual(
925                     CacheWrapper(
926                         nn.LayerNorm((dim_model,)),
927                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
928                         nn.ReLU(),
929                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
930                         nn.Dropout(dropout),
931                     ),
932                 ),
933             ]
934
935         self.trunk = nn.Sequential(*trunk_blocks)
936
937         self.readout = CacheWrapper(
938             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
939         )
940
941         with torch.no_grad():
942             for m in self.modules():
943                 if isinstance(m, nn.Embedding):
944                     m.weight.normal_(mean=0, std=2e-2)
945                 elif isinstance(m, nn.LayerNorm):
946                     m.bias.zero_()
947                     m.weight.fill_(1.0)
948
949         self.reset_inner_loss()
950
951     def forward(self, bs):
952         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
953
954         # To make the code simpler in the Caterpillar layer, we pad
955         # here. It's unclear if/how much it hurts computationaly by
956         # increasing the sequence length for the other layers
957
958         if self.caterpillar_length > 0:
959             original_nb = bs.nb
960             if bs.nb % self.caterpillar_length > 0:
961                 bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
962
963             bs = BracketedSequence(
964                 F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
965                 bs.first + self.caterpillar_length,
966                 bs.nb,
967                 bs.init_cache,
968             )
969
970         bs = self.embedding(bs)
971         bs = self.trunk(bs)
972         bs = self.readout(bs)
973
974         if self.caterpillar_length > 0:
975             bs = BracketedSequence(
976                 F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
977                 bs.first - self.caterpillar_length,
978                 original_nb,
979                 bs.init_cache,
980             )
981
982         return bs
983
984     # ar_mask is a tensor with 0s and 1s, of same shape as input, with
985     # 1s where tokens should be generated. The others are kept
986     # unchanged.
987
988     def masked_inplace_autoregression(
989         self,
990         input_src,
991         ar_mask_src,
992         forbidden_tokens=None,
993         deterministic_synthesis=False,
994     ):
995         input = input_src.to(self.readout.f.weight.device)
996         ar_mask = ar_mask_src.to(self.readout.f.weight.device)
997         to_generate = (ar_mask.sum(0) > 0).nonzero()
998         if to_generate.min() > 0:
999             self(
1000                 BracketedSequence(input, 0, to_generate.min(), True)
1001             )  # Needed to initialize the model's cache
1002         for s in range(to_generate.min(), to_generate.max() + 1):
1003             output = self(BracketedSequence(input, s, 1, s == 0)).x
1004             logits = output[:, s]
1005             if forbidden_tokens is not None:
1006                 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
1007             if deterministic_synthesis:
1008                 t_next = logits.argmax(1)
1009             else:
1010                 dist = torch.distributions.categorical.Categorical(logits=logits)
1011                 t_next = dist.sample()
1012             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
1013
1014         input_src.copy_(input)
1015
1016     def reset_inner_loss(self):
1017         for m in self.modules():
1018             if m is not self and hasattr(m, "reset_inner_loss"):
1019                 m.reset_inner_loss()
1020
1021     def get_inner_loss(self):
1022         l = torch.tensor([0.0], device=self.readout.f.weight.device)
1023         for m in self.modules():
1024             if m is not self and hasattr(m, "get_inner_loss"):
1025                 l += m.get_inner_loss()
1026         return l
1027
1028     def record_attention(self, v=True):
1029         for m in self.modules():
1030             if isinstance(m, QKVAttention):
1031                 m.record_attention = v
1032
1033     def retrieve_attention(self):
1034         a = []
1035         for m in self.modules():
1036             if isinstance(m, QKVAttention):
1037                 a.append(m.a)
1038         return a
1039
1040
1041 ######################################################################
1042
1043 if __name__ == "__main__":
1044     print("Basic check.")
1045
1046     m = Caterpillar(
1047         dim_model=4,
1048         dim_qk=3,
1049         dim_v=7,
1050         nb_heads=1,
1051         caterpillar_length=7,
1052         caterpillar_height=3,
1053         attention_dropout=0.0,
1054     )
1055
1056     m.reset_inner_loss()
1057     x = torch.randn(1, 21 + 2 * 7, 4)
1058     y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1059     y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
1060     y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
1061     y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
1062     print((y1 - y2).abs().max())
1063     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
1064     exit(0)
1065
1066     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1067
1068     vocabulary_size = 128
1069     x = torch.randint(vocabulary_size, (6, 1024))
1070
1071     model = MyGPT(
1072         vocabulary_size=vocabulary_size,
1073         dim_model=512,
1074         dim_keys=64,
1075         dim_hidden=2048,
1076         nb_heads=8,
1077         nb_lines=128,
1078         nb_blocks=12,
1079         dropout=0.1,
1080         causal=True,
1081     )
1082
1083     x = x.to(device)
1084     model.to(device)
1085
1086     import time, sys
1087
1088     # import torchvision.models as models
1089     # from torch.profiler import profile, record_function, ProfilerActivity
1090
1091     # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
1092     # with record_function("model_inference"):
1093
1094     model.eval()
1095     for i in range(3):
1096         start_time = time.perf_counter()
1097         for k in range(10):
1098             model(BracketedSequence(x))
1099         duration = time.perf_counter() - start_time
1100         print(duration)
1101         sys.stdout.flush()
1102
1103     # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
1104     # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1105
1106     # print("##############################################################")
1107     # y2 = torch.randn_like(y1)
1108     # for s in range(x.size(1)):
1109     # z = model(BracketedSequence(x, s, 1))
1110     # y2[:, s : s + 1] = z.slice()
1111
1112     # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
1113
1114 ######################################################################