Update.
[picoclvr.git] / graph.py
index a819283..c286388 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -161,7 +161,7 @@ if __name__ == "__main__":
         nb_heads=2,
         nb_blocks=5,
         dropout=0.1,
-        causal=True,
+        #causal=True,
     )
 
     model.eval()
@@ -171,6 +171,8 @@ if __name__ == "__main__":
 
     attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
 
+    
+
     # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
     # for a in attention_matrices: a=a/a.sum(-1,keepdim=True)