- # From Vaswani et al 2018
- # PE_{t,2i} = sin(t/(L^{2i/D}))
- # PE_{t,2i+1} = cos(t/(L^{2i/D}))
+ # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
def forward(self, x):
    t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
    j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
    k = j % 2  # 1 on odd channels, where the pi/2 shift turns sin into cos
    # self.len_max is assumed to hold the base L from the formula above
    pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
    return x + pe
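
# Standalone sanity sketch (hypothetical sizes, L = 10000 as in the paper):
# the single shifted sine reproduces the paper's sin/cos pair, because
# sin(u + pi/2) = cos(u).
import math, torch
_t = torch.arange(4.0)[:, None]
_j = torch.arange(6.0)[None, :]
_k = _j % 2
_pe = torch.sin(_t / (10000.0 ** ((_j - _k) / 6)) + math.pi / 2 * _k)
assert torch.allclose(_pe[:, 0], torch.sin(_t[:, 0]))  # even channel: sin
assert torch.allclose(_pe[:, 1], torch.cos(_t[:, 0]))  # odd channel: cos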
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
# w_o is used in forward below; its shape is assumed to match y.flatten(2) @ self.w_o
self.w_o = randw(nb_heads * dim_v, dim_in)
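
# randw is not defined in this excerpt; a minimal sketch of what it is assumed
# to do, namely return a trainable weight scaled by 1/sqrt(fan-in):
import math, torch
from torch import nn

def randw(*d):
    return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))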
def forward(self, x_q, x_kv = None):
    # defaults to self-attention; pass x_kv for cross-attention
    if x_kv is None: x_kv = x_q
    # per-head projections of the input channels: (N, T, C) -> (N, H, T, d)
    q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
    k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
    v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
    # attention scores: query/key dot products scaled by 1/sqrt(dim_qk)
    a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
    if self.causal:
        # forbid query position t from attending to key positions s > t
        mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
               < torch.arange(a.size(3), device = q.device)[None, None, None, :]
        a = a.masked_fill(mask, float('-inf'))

    a = a.softmax(dim = 3)  # normalize scores over the key (s) axis
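
# Broadcasting sketch with T = S = 4: the comparison above yields a strictly
# upper-triangular boolean mask, True exactly at the future positions s > t.
import torch
_m = torch.arange(4)[:, None] < torch.arange(4)[None, :]
assert _m[1, 2] and not _m[2, 1] and not _m[2, 2]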
-     y = torch.einsum('nhts,nhsd->nthd', a, v)
-     y = y.flatten(2) @ self.w_o
+     y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
+
+     y = y @ self.w_o
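
# Shape walkthrough of the attention path with hypothetical sizes
# N = 2, T = 5, nb_heads = 3, dim_in = 8, dim_qk = dim_v = 4:
import math, torch
_x = torch.randn(2, 5, 8)
_wq, _wk, _wv = (torch.randn(3, 4, 8) for _ in range(3))
_q = torch.einsum('ntc,hdc->nhtd', _x, _wq)                  # (2, 3, 5, 4)
_k = torch.einsum('ntc,hdc->nhtd', _x, _wk)
_v = torch.einsum('ntc,hdc->nhtd', _x, _wv)
_a = (torch.einsum('nhtd,nhsd->nhts', _q, _k) / math.sqrt(4)).softmax(dim = 3)
_y = torch.einsum('nhts,nhsd->nthd', _a, _v).flatten(2)      # (2, 5, 12)
assert (_y @ torch.randn(12, 8)).shape == (2, 5, 8)          # back to dim_in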
def __init__(self,
             vocabulary_size,
             dim_model, dim_keys, dim_hidden,
self.embedding = nn.Sequential(
    nn.Embedding(vocabulary_size, dim_model),
    nn.Dropout(dropout),
# position-wise feed-forward block: expand to dim_hidden, project back to dim_model
nn.Linear(in_features = dim_model, out_features = dim_hidden),
nn.ReLU(),
nn.Linear(in_features = dim_hidden, out_features = dim_model),
- dim_model = 16, dim_keys = 50, dim_hidden = 100,
+ dim_model = 18, dim_keys = 50, dim_hidden = 100,