OCDC
[mygpt.git] / mygpt.py
index 43711b3..7c4e06d 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe # Let broadcasting to its job
 
 ##############################