Update.
[mygptrnn.git] / mygpt.py
index 33c6fee..c061eb4 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -441,6 +441,11 @@ class KVRec(nn.Module):
 ##############################
 
 
+# Returns a tensor with an additional index at rank win_dim, that move
+# along the same dimension as dim, on a domain {0...win_size-1}, and
+# dim is restricted on a domain reduced by win_size-1 values.
+
+
 def moving_window(x, dim, win_dim, win_size):
     size, stride = x.size(), x.stride()
     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
@@ -540,6 +545,8 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        G = F.dropout(G, self.attention_dropout, self.training)
+
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)