X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=4555b1e4de65517051d10e5249f0cba7cc496d35;hb=HEAD;hp=5ea4668203f0f4eaf9ccbc586f2d58a348fe0a3f;hpb=3602bfe2c4e1cd513759bf45cb83f8c2d914674b;p=beaver.git

diff --git a/mygpt.py b/mygpt.py
index 5ea4668..4555b1e 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -14,19 +14,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -57,6 +44,19 @@ class BracketedSequence:
 ######################################################################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+######################################################################
+
+
 class CacheWrapper(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -85,22 +85,41 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs):
+    def forward(self, bs, order):  # NxTxD, T
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+            t = (
+                torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+                    :, None
+                ]
+                - 1
+            )
+            j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
                 None, :
             ]
             k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            pe = (
+                torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                .unsqueeze(0)
+                .expand(bs.x.size(0), -1, -1)
+            )
+
+            order_output = order + 1
+            order_input = F.pad(order + 1, (1, -1))
+
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
+
+            self.pe = torch.cat((pe_input, pe_output), 2)
         self.cache_y = bs.x.new(bs.x.size())
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
         bs.x = self.cache_y
 
         return bs
@@ -113,13 +132,27 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+        amm_generator=None,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        if amm_generator is None:
+            self.amm_generator = (
+                lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
+            )
+        else:
+            self.amm_generator = amm_generator
+
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -156,10 +189,9 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+                    None, None, :, :
+                ]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -196,16 +228,16 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
+        amm_generator=None,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+        self.embedding = CacheWrapper(
+            nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
         )
+        self.pe = AddPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -220,6 +252,7 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     causal=causal,
                     attention_dropout=dropout,
+                    amm_generator=amm_generator,
                 ),
             ),
             WithResidual(
@@ -247,18 +280,34 @@ class MyGPT(nn.Module):
                 m.bias.zero_()
                 m.weight.fill_(1.0)
 
-    def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+    def forward(self, bs, mode="standard", order=None):
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        if order is None:
+            order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+                bs.x
+            )
         bs = self.embedding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
+        bs = self.pe(bs, order)
+
+        if mode == "standard":
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+        elif mode == "head":
+            bs = self.trunk(bs)
+        elif mode == "deep":
+            r = []
+            for l in self.trunk:
+                bs = l(bs)
+                r += [bs.slice()]
+            bs = BracketedSequence(torch.cat(r, -1))
+        else:
+            raise ValueError(f"{mode=}")
 
         return bs
 
 
 ######################################################################
 
 if __name__ == "__main__":
-    print("Basic check.")
 
     vocabulary_size = 10
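
Note on the new amm_generator hook (illustration, not part of the patch): QKVAttention now accepts a callable that, given the sequence length, returns a TxT boolean tensor which is True at the (query, key) pairs whose attention coefficients must be suppressed; when it is None, the constructor falls back to the strict upper-triangular comparison, i.e. the usual causal mask. The sketch below shows that contract under the assumption that the unchanged part of QKVAttention.forward fills masked coefficients with -inf before the softmax; the names default_mask and prefix_mask and the toy sizes are mine, and prefix_mask is purely hypothetical.

import torch


def default_mask(d):
    # Same comparison as the fallback installed when amm_generator is None:
    # True where the key comes strictly after the query (causal masking).
    return torch.arange(d)[:, None] < torch.arange(d)[None, :]


def prefix_mask(d, prefix=4):
    # Hypothetical variant: the first `prefix` positions attend to each other
    # freely, the remaining positions stay causal.
    m = default_mask(d)
    m[:prefix, :prefix] = False
    return m


T = 6
a = torch.randn(1, 1, T, T)  # stand-in for the NxHxTxT attention logits

# The module broadcasts the generated mask to NxHxTxT (the [None, None, :, :]
# above) and removes the masked coefficients before normalizing.
cache_attzero = default_mask(T)[None, None, :, :]
a = a.masked_fill(cache_attzero, float("-inf"))
a = a.softmax(dim=3)

print(a[0, 0])  # row t has non-zero weight only on columns <= t

The other two additions follow the same pattern of making previously hard-coded behaviour a parameter: AddPositionalEncoding.forward now takes an order tensor and gathers separate input and output positional codes per token, concatenated on the feature dimension, and MyGPT.forward gains a mode switch where "standard" runs trunk plus readout, "head" stops after the trunk, and "deep" returns the per-block activations concatenated on the last dimension.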