import ffutils
-from blanket import blanket
+# from blanket import blanket
# import memload
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
- V, K = blanket(V), blanket(K)
+ # V, K = blanket(V), blanket(K)
######################################################################
# Compute the recurrent state
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
- Q = blanket(Q)
+ # Q = blanket(Q)
# We build tensors NxHxTxRxL where N is the sample index, H
# the head, T the time, R the row in the caterpillar, and L
# Compute the final output
- Y = blanket(Y)
+ # Y = blanket(Y)
self.cache_Y[:, t0:t1] = Y @ self.w_O