X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=492a9bb96872e93f99ea9d9609ba64fe557c57fa;hb=e3d5af800ccd197580265709c4499bf281beecb8;hp=a27b99e8dd47eb14696257fb1d814c8e33dd49cb;hpb=64dc96ddfa84511ba07d1929481e93e864735409;p=mygptrnn.git

diff --git a/mygpt.py b/mygpt.py
index a27b99e..492a9bb 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -597,36 +597,14 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        # warnings.warn("softmax gating", RuntimeWarning)
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
 
-        # G = (
-        #     torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
-        # ).softmax(dim=2)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
         ######################################################################
-        # The "flashbacks"
-
-        if self.training and self.proba_gate_dropout > 0.0:
-            # This is a better implementation of "flashbacks".
-
-            # G is NxHxExT where e is the caterpillar's row.
-
-            warnings.warn("gate dropout", RuntimeWarning)
-
-            kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
-
-            alpha = G / (1 - self.proba_gate_dropout)
-
-            G = alpha * (1 - kill)
 
         def recurrence(G, V, K):
-            # Clip the gating to avoid values greater than 1 when several
-            # heads hit the same row
-
-            G = G / G.sum(1, keepdim=True).clamp(min=1)
-
             # We prepare the arguments for the parallel scan
 
             A = 1 - G.sum(1)
@@ -663,6 +641,26 @@ class Caterpillar(nn.Module):
 
         next_V, next_K = recurrence(G, V, K)
 
+        if self.training and self.proba_gate_dropout > 0.0:
+            # G is NxHxRxT where r is the caterpillar's row.
+
+            warnings.warn("gate dropout", RuntimeWarning)
+
+            kill = (
+                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            ).float()
+
+            mask = 1 - kill
+
+            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K
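
Note on the reworked gate dropout above: instead of masking G before the scan, the new code runs the recurrence twice and recombines the two outputs with detach(), so the forward value remains the clean recurrence output while the gradient comes only from the masked recurrence, rescaled by 1 / (1 - proba_gate_dropout). Below is a minimal, self-contained sketch of that detach trick; the shapes, the toy recurrence() and the probability p are illustrative stand-ins, not the module's actual API.

import torch

p = 0.25                                           # stand-in for proba_gate_dropout
G = torch.rand(2, 3, 4, 5, requires_grad=True)     # toy gating, N x H x R x T

def recurrence(G):
    # stand-in for the real parallel scan: any differentiable function of G
    return (2 * G).sum()

clean = recurrence(G)

kill = (torch.rand_like(G) <= p).float()           # 1 where a gate is dropped
masked = recurrence(G * (1 - kill))

# Forward: the two masked terms cancel numerically, so out == clean.
# Backward: only the non-detached masked term contributes, scaled by 1 / (1 - p).
out = clean.detach() + (masked - masked.detach()) / (1 - p)

print(torch.allclose(out, clean))                  # True: forward pass unchanged
out.backward()
print((G.grad * kill).abs().max())                 # 0.: dropped gates get no gradient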