# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
# This is an implementation from scratch of a "GPT", that is a model
# composed of several causal self-attention blocks. It is equipped
# with a caching mechanism for keys and values to avoid an O(N^3) cost
# for auto-regression.
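#
# Without the cache, generating the t-th token means re-running the
# attention over the whole prefix, i.e. O(t^2) work per step, hence
# O(N^3) in total for a sequence of N tokens. With the cached keys and
# values each step costs O(t), hence O(N^2) overall.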

import math, time, warnings

import torch

from torch import nn
from torch.nn import functional as F
from functorch.dim import dims

import pscan  # local module providing the parallel prefix scan pscan.pscan

######################################################################

# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to compute.
#
# Modules able to process it expect that they will have to process a
# first bracket starting at t=0, followed by a succession of brackets
# that move forward in time, do not overlap, and cover the axis T with
# no holes.
#
# Although it is more general, for a classical prompt-conditioned
# auto-regressive process it will be a first bracket starting at 0 and
# of arbitrary length for the "prompt", followed by brackets of length
# 1 for the successive tokens.
#
# Modules able to process brackets may implement a cache that is
# reset when the input bracket starts at t=0.


class BracketedSequence:
    def __init__(self, x, first=None, nb=None, init_cache=None):
        self.x = x
        assert (first is None and nb is None and init_cache is None) or (
            first is not None and nb is not None and init_cache is not None
        )
        self.first = 0 if first is None else first
        self.nb = x.size(1) if nb is None else nb
        self.init_cache = True if init_cache is None else init_cache

    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

    def complete(self):
        return self.first == 0 and self.nb == self.x.size(1)
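
# Usage sketch (illustration only): the bracket protocol for a prompt
# of length 5 followed by five generated tokens, on a B=2, T=10, D=4
# tensor.
#
#   x = torch.randn(2, 10, 4)
#   full = BracketedSequence(x)  # first=0, nb=10, init_cache=True
#   prompt = BracketedSequence(x, first=0, nb=5, init_cache=True)
#   steps = [BracketedSequence(x, t, 1, False) for t in range(5, 10)]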

######################################################################


class CacheWrapper(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        if bs.init_cache:
            y = self.f(bs.slice())
            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
            self.cache_y[:, bs.first : bs.first + bs.nb] = y
        else:
            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
            assert bs.first + bs.nb <= self.cache_y.size(1)
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)

##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, bs):
        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)

##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))

    def forward(self, bs):
        if bs.init_cache:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            # A single sin() covers both cases since sin(x + pi/2) = cos(x)
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
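
# Worked form of the formula above (illustration only): with k = j % 2
# and D = bs.x.size(2), the single sin() expands to the usual pair
#
#   PE[t, 2i]   = sin(t / len_max^(2i / D))
#   PE[t, 2i+1] = sin(t / len_max^(2i / D) + pi/2) = cos(t / len_max^(2i / D))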

##############################


# X is /.../xTxD  A is /.../xT  Y_init is /.../xD


def pscan_dim(A, X, Y_init, dim=-2):
    s = X.size()
    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()

    A = A.reshape(a, T, *s[dim + 1 : -1])
    X = X.reshape(a, T, *s[dim + 1 : -1], -1)

    if Y_init is None:
        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
    else:
        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y
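
# Reference semantics (illustration only, assuming pscan.pscan
# implements the first-order linear recurrence
# Y[t] = A[t] * Y[t-1] + X[t] in parallel):
#
#   def pscan_slow(A, X, Y_init):
#       # A is (N, T, ...), X is (N, T, ..., D), Y_init is (N, ..., D)
#       Y, ys = Y_init, []
#       for t in range(X.size(1)):
#           Y = A[:, t].unsqueeze(-1) * Y + X[:, t]
#           ys.append(Y)
#       return torch.stack(ys, dim=1)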


def pscan_shape(A, X, Y_init):
    s = X.size()
    A = A.reshape(-1, s[-2])
    X = X.reshape(-1, s[-2], s[-1])

    if Y_init is None:
        Y_init = X.new_zeros(X.size(0), s[-1])
    else:
        Y_init = Y_init.reshape(-1, s[-1])

    Y = pscan.pscan(A, X, Y_init).reshape(s)

    return Y


def nsum_shape(X, Y_init):
    s = X.size()
    X = X.reshape(-1, s[-2], s[-1])  # ntd

    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
    result = []

    for k in range(X.size(1)):
        Y = Y + X[:, k]
        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
        result.append(Y[:, None])

    return torch.cat(result, dim=1).reshape(s)

##############################


class DumbRec(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        # self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            # self.rec_k = x_q.new_zeros(
            #     x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            # )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        # The rotating barrel below overrides the static keys above
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            self.acc_attention += aw.sum(dim=(0, 1, 3))
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            # K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
            self.k_star,
            # self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
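
# DumbRec stores values in nb_lines memory lines updated with a gated
# linear recurrence: the write attention aw (heads x lines x time)
# decides how much of v[t] each line absorbs, A = 1 - sum_h aw is the
# complement that survives, so rec_v[l, t] = A[l, t] * rec_v[l, t-1] +
# V[l, t], which is exactly the recurrence pscan_shape computes. The
# readout is a standard attention over the lines, with the static
# (barrel-rotated) keys k_star.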

##############################


class KVRec(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        nb_lines,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_lines = nb_lines
        self.attention_dropout = attention_dropout

        self.k_star = randw(nb_lines, dim_qk)

        self.w_qw = randw(nb_heads, dim_qk, dim_in)
        self.w_qr = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        warnings.warn("l2 regularization", RuntimeWarning)
        return (self.acc_attention / self.acc_nb).pow(2).sum()
        # return torch.tensor([0], device=self.w_qw.device)
        # warnings.warn("side regularization", RuntimeWarning)
        # return (
        #     (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        # )
        # return torch.tensor([0], device=self.w_qw.device)

    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb

        # n,h,l,t,d = dims(5)

        if bs.init_cache:
            self.rec_v = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
            )
            self.rec_k = x_q.new_zeros(
                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        ######################################################################
        # Prepare the keys

        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)

        # The rotating barrel below overrides the static keys above
        warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]

        ######################################################################
        # Compute the recurrent state

        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)

        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)

        aw = torch.einsum(
            "nhtd,ltd->nhlt",
            qw,
            k_star,
        ) / math.sqrt(self.w_qw.size(1))

        aw = aw.softmax(dim=2)  # nhlt

        if self.training:
            # We want all the memory lines to be used similarly
            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)

        aw = F.dropout(aw, self.attention_dropout, self.training)

        A = 1 - aw.sum(dim=1)  # nlt

        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()

        if t0 == 0:
            V0 = None
            K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
            K0 = self.rec_k[:, :, t0 - 1]

        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)

        ######################################################################
        # compute the readout

        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)

        ar = torch.einsum(
            "nhtd,nltd->nhlt",
            qr,
            self.rec_k[:, :, t0:t1],
        ) / math.sqrt(self.w_qr.size(1))

        ar = ar.softmax(dim=2)  # nhlt

        ar = F.dropout(ar, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhlt,nltd->nthd",
            ar,
            self.rec_v[:, :, t0:t1],
        ).flatten(2)

        self.cache_y[:, t0:t1] = y @ self.w_o

        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
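
# KVRec differs from DumbRec in that the keys are also accumulated
# recurrently (rec_k), and the readout attends to those stored keys
# instead of the static k_star.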

##############################


def moving_window(x, dim, win_dim, win_size):
    size, stride = x.size(), x.stride()
    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
    size = size[:win_dim] + (win_size,) + size[win_dim:]
    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]

    return x.as_strided(size=size, stride=stride)
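
# Example (illustration only): with x of shape (2, 7, 3) and
# moving_window(x, dim=1, win_dim=2, win_size=3), the result has shape
# (2, 5, 3, 3) and result[n, t, k] == x[n, t + k], i.e. a sliding
# window of width 3 along the time axis, built with as_strided and
# hence without copying any data.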

##############################


class Caterpillar(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads,
        caterpillar_length,
        caterpillar_height,
        attention_dropout=0.0,
    ):
        super().__init__()

        warnings.warn("Caterpillar", RuntimeWarning)

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.caterpillar_length = caterpillar_length
        self.caterpillar_height = caterpillar_height
        self.attention_dropout = attention_dropout

        self.w_G = randw(nb_heads, caterpillar_height, dim_in)
        self.b_G = nn.Parameter(
            torch.full(
                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
            )
        )

        self.w_K = randw(nb_heads, dim_qk, dim_in)
        self.w_V = randw(nb_heads, dim_v, dim_in)
        self.w_Q = randw(nb_heads, dim_qk, dim_in)
        self.w_O = randw(dim_v * nb_heads, dim_in)

        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)

    def reset_inner_loss(self):
        self.acc_attention = 0
        self.acc_nb = 0

    def get_inner_loss(self):
        # warnings.warn("l2 regularization", RuntimeWarning)
        # return (self.acc_attention / self.acc_nb).pow(2).sum()
        return torch.tensor([0], device=self.w_Q.device)

    def forward(self, bs):
        # Dimension names, to make the source a bit clearer
        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb

        N = bs.x.size(0)
        T = bs.x.size(1)
        DV = self.w_V.size(1)
        DK = self.w_K.size(1)
        Dout = self.w_O.size(1)
        CH = self.caterpillar_height
        CL = self.caterpillar_length

        assert (
            t0 >= CL and (t1 - t0) % CL == 0
        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"

        if bs.init_cache:
            self.rec_V = X.new_zeros(N, CH, T, DV)
            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
            self.rec_K = X.new_zeros(N, CH, T, DK)
            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
            self.cache_Y = X.new_zeros(N, T, Dout)

        ######################################################################
        # Compute the recurrent state

        G = (
            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
        ).sigmoid()

        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)

        A = 1 - G.sum(dim=1)  # net
        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
        gated_K = torch.einsum("nhet,nhtd->netd", G, K)

        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
        init_rec_K = self.rec_K[:, :, t0 - CL : t0]

        A = A.unflatten(2, (-1, CL))
        gated_V = gated_V.unflatten(2, (-1, CL))
        gated_K = gated_K.unflatten(2, (-1, CL))

        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)

        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)

        ######################################################################
        # compute the readout

        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)

        windowed_V = moving_window(
            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )
        windowed_K = moving_window(
            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
        )

        ar = torch.einsum(
            "nhtd,netld->nhtel",
            Q,
            windowed_K,
        ) / math.sqrt(DK)

        ar = ar.flatten(3).softmax(dim=3).view(ar.size())

        ar = F.dropout(ar, self.attention_dropout, self.training)

        Y = torch.einsum(
            "nhtel,netld->nthd",
            ar,
            windowed_V,
        ).flatten(2)

        self.cache_Y[:, t0:t1] = Y @ self.w_O

        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
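
# The "caterpillar" keeps CH recurrent lines, each replicated over CL
# interleaved time streams: position t is updated from position t - CL
# with the gated recurrence rec[t] = A[t] * rec[t - CL] + gated[t],
# which pscan_dim computes chunk by chunk along dim=2. The readout at
# time t then attends, through moving_window, to the CL most recent
# recurrent states of every line.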

##############################


class QKVAttention(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_qk,
        dim_v,
        nb_heads=1,
        causal=False,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout
        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, bs):
        x_q = bs.x

        assert (
            self.causal or bs.complete()
        ), "Partial evaluation is only possible for causal models"

        if bs.init_cache:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
            )
            self.cache_v = x_q.new_zeros(
                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
            )
            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))

        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)

        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
        )
        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
        ) / math.sqrt(self.w_q.size(1))

        if self.causal:
            if bs.init_cache:
                self.cache_attzero = (
                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                )
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
                ],
                float("-inf"),
            )

        a = a.softmax(dim=3)

        if self.record_attention:
            self.a = a

        a = F.dropout(a, self.attention_dropout, self.training)

        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
        ).flatten(2)

        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o

        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
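
# For a complete causal pass (bs.first == 0, bs.nb == T), the block
# above computes standard masked multi-head attention; with q, k, v of
# shape (N, H, T, D) it matches, up to dropout placement (illustration
# only, requires PyTorch >= 2.0):
#
#   y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#   y = y.transpose(1, 2).flatten(2)  # nhtd -> nthd -> nt(hd)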

##############################


class MyGPT(nn.Module):
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        nb_lines=None,
        caterpillar_height=None,
        causal=False,
        dropout=0.0,
        len_max=1e5,
        attention_layer="kvrec",
    ):
        super().__init__()

        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}

        if attention_layer == "caterpillar":
            assert nb_lines % caterpillar_height == 0
            self.caterpillar_length = nb_lines // caterpillar_height
            self.caterpillar_height = caterpillar_height
        else:
            self.caterpillar_length = -1
            self.caterpillar_height = -1

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )

        def attlayer():
            if attention_layer == "mha":
                return QKVAttention(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    causal=causal,
                    attention_dropout=dropout,
                )
            elif attention_layer == "dumbrec":
                return DumbRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "kvrec":
                return KVRec(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    nb_lines=nb_lines,
                    attention_dropout=dropout,
                )
            elif attention_layer == "caterpillar":
                return Caterpillar(
                    dim_in=dim_model,
                    dim_qk=dim_keys,
                    dim_v=dim_model // nb_heads,
                    nb_heads=nb_heads,
                    caterpillar_length=self.caterpillar_length,
                    caterpillar_height=self.caterpillar_height,
                    attention_dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown attention type {attention_layer}.")

        trunk_blocks = []

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    CacheWrapper(nn.LayerNorm((dim_model,))),
                    attlayer(),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                        nn.Dropout(dropout),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

        self.reset_inner_loss()

    def forward(self, bs):
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)

        # To make the code simpler in the Caterpillar layer, we pad
        # here. It's unclear if/how much it hurts computationally by
        # increasing the sequence length for the other layers

        if self.caterpillar_length > 0:
            if bs.nb % self.caterpillar_length > 0:
                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length

            bs = BracketedSequence(
                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
                bs.first + self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)

        if self.caterpillar_length > 0:
            bs = BracketedSequence(
                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
                bs.first - self.caterpillar_length,
                bs.nb,
                bs.init_cache,
            )

        return bs

    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
    # 1s where tokens should be generated. The others are kept
    # unchanged.

    def masked_inplace_autoregression(
        self,
        input_src,
        ar_mask_src,
        forbidden_tokens=None,
        deterministic_synthesis=False,
    ):
        input = input_src.to(self.readout.f.weight.device)
        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
        to_generate = (ar_mask.sum(0) > 0).nonzero()
        if to_generate.min() > 0:
            self(
                BracketedSequence(input, 0, to_generate.min(), True)
            )  # Needed to initialize the model's cache
        for s in range(to_generate.min(), to_generate.max() + 1):
            output = self(BracketedSequence(input, s, 1, s == 0)).x
            logits = output[:, s]
            if forbidden_tokens is not None:
                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
            if deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]

        input_src.copy_(input)
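
    # Usage sketch (illustration only): generate the last 16 tokens of
    # every sequence, keeping the first part as a fixed prompt.
    #
    #   input = torch.randint(vocabulary_size, (N, T))
    #   ar_mask = torch.zeros_like(input)
    #   ar_mask[:, -16:] = 1
    #   model.masked_inplace_autoregression(input, ar_mask)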

    def reset_inner_loss(self):
        for m in self.modules():
            if m is not self and hasattr(m, "reset_inner_loss"):
                m.reset_inner_loss()

    def get_inner_loss(self):
        l = torch.tensor([0.0], device=self.readout.f.weight.device)
        for m in self.modules():
            if m is not self and hasattr(m, "get_inner_loss"):
                l += m.get_inner_loss()

        return l

    def record_attention(self, v=True):
        for m in self.modules():
            if isinstance(m, QKVAttention):
                m.record_attention = v

    def retrieve_attention(self):
        a = []
        for m in self.modules():
            if isinstance(m, QKVAttention):
                a.append(m.a)
        return a

######################################################################

if __name__ == "__main__":
    print("Basic check.")

    # The constructor arguments lost to elision are illustrative
    # values, chosen to be consistent with the x defined below
    # (dim_in=4)
    m = Caterpillar(
        dim_in=4,
        dim_qk=3,
        dim_v=7,
        nb_heads=1,
        caterpillar_length=7,
        caterpillar_height=3,
        attention_dropout=0.0,
    )

    x = torch.randn(1, 21 + 2 * 7, 4)
    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
    print((y1 - y2).abs().max())
    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocabulary_size = 128
    x = torch.randint(vocabulary_size, (6, 1024))

    # The hyper-parameters lost to elision are illustrative values
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=512,
        dim_keys=64,
        dim_hidden=2048,
        nb_heads=8,
        nb_lines=128,
        nb_blocks=12,
        dropout=0.1,
        causal=True,
    )

    model.eval()
    model.to(device)
    x = x.to(device)

    # import torchvision.models as models
    # from torch.profiler import profile, record_function, ProfilerActivity

    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    #     with record_function("model_inference"):

    start_time = time.perf_counter()

    model(BracketedSequence(x))

    duration = time.perf_counter() - start_time

    print(f"duration={duration}")

    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # print("##############################################################")
    # y2 = torch.randn_like(y1)
    # for s in range(x.size(1)):
    #     z = model(BracketedSequence(x, s, 1))
    #     y2[:, s : s + 1] = z.slice()

    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")

######################################################################