From 8492656cf0cc5de4f7e2c4aa8ccb717193293b40 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 22:59:10 +0200 Subject: [PATCH] Update. --- graph.py | 4 +++- mygpt.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/graph.py b/graph.py index a819283..c286388 100755 --- a/graph.py +++ b/graph.py @@ -161,7 +161,7 @@ if __name__ == "__main__": nb_heads=2, nb_blocks=5, dropout=0.1, - causal=True, + #causal=True, ) model.eval() @@ -171,6 +171,8 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] + + # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] # for a in attention_matrices: a=a/a.sum(-1,keepdim=True) diff --git a/mygpt.py b/mygpt.py index 0400b48..0cf70e0 100755 --- a/mygpt.py +++ b/mygpt.py @@ -46,7 +46,7 @@ class BracketedSequence: return self.x[:, self.first : self.first + self.nb] def complete(self): - return self.first == 0 and self.nb == x.size(1) + return self.first == 0 and self.nb == self.x.size(1) ###################################################################### -- 2.20.1