return self.x[:, self.first : self.first + self.nb]
def complete(self):
    """Return True when this bracketed slice spans the entire cached sequence.

    True iff the visible window starts at position 0 and its length ``nb``
    equals the full sequence length ``self.x.size(1)``.

    Bug fix: the comparison previously read the bare name ``x`` (a
    NameError at runtime); it must read the instance attribute ``self.x``.
    """
    # assumes self.x is a (batch, seq, ...) tensor — dim 1 is sequence length
    return self.first == 0 and self.nb == self.x.size(1)
######################################################################
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
- if self.record_attention:
- self.a = a
-
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(