OCDC
authorFrancois Fleuret <francois@fleuret.org>
Wed, 27 Jul 2022 14:42:47 +0000 (16:42 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 27 Jul 2022 14:42:47 +0000 (16:42 +0200)
mygpt.py

index 43711b3..7c4e06d 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe # Let broadcasting to its job
 
 ##############################