Update.
[mygptrnn.git] / mygpt.py
index 9a02bcd..67c5cfd 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 
 import ffutils
 
-from blanket import blanket
+from blanket import blanket
 
 # import memload
 
@@ -569,7 +569,7 @@ class Caterpillar(nn.Module):
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
-        V, K = blanket(V), blanket(K)
+        V, K = blanket(V), blanket(K)
 
         ######################################################################
         # Compute the recurrent state
@@ -673,7 +673,7 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
-        Q = blanket(Q)
+        Q = blanket(Q)
 
         # We build tensors NxHxTxRxL where N is the sample index, H
         # the head, T the time, R the row in the caterpillar, and L
@@ -712,7 +712,7 @@ class Caterpillar(nn.Module):
 
         # Compute the final output
 
-        Y = blanket(Y)
+        Y = blanket(Y)
 
         self.cache_Y[:, t0:t1] = Y @ self.w_O