projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index 9a02bcd..67c5cfd 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -21,7 +21,7 @@
from torch.nn import functional as F
import ffutils
import ffutils
-from blanket import blanket
+# from blanket import blanket
# import memload
# import memload
@@ -569,7 +569,7 @@
class Caterpillar(nn.Module):
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
- V, K = blanket(V), blanket(K)
+ # V, K = blanket(V), blanket(K)
######################################################################
# Compute the recurrent state
######################################################################
# Compute the recurrent state
@@ -673,7 +673,7 @@
class Caterpillar(nn.Module):
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
- Q = blanket(Q)
+ # Q = blanket(Q)
# We build tensors NxHxTxRxL where N is the sample index, H
# the head, T the time, R the row in the caterpillar, and L
# We build tensors NxHxTxRxL where N is the sample index, H
# the head, T the time, R the row in the caterpillar, and L
@@ -712,7 +712,7 @@
class Caterpillar(nn.Module):
# Compute the final output
# Compute the final output
- Y = blanket(Y)
+ # Y = blanket(Y)
self.cache_Y[:, t0:t1] = Y @ self.w_O
self.cache_Y[:, t0:t1] = Y @ self.w_O