From: Francois Fleuret Date: Wed, 27 Jul 2022 14:42:47 +0000 (+0200) Subject: OCDC X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=cd1cc80f711ca1f7188cc9854f18231e02470eba OCDC --- diff --git a/mygpt.py b/mygpt.py index 43711b3..7c4e06d 100755 --- a/mygpt.py +++ b/mygpt.py @@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] k = j%2 - return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :] + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + return x + pe # Let broadcasting to its job ##############################