X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=attentiontoy1d.py;h=d7f06fe0b587ba8f08dbfdda93ca58728a955f84;hp=6540a0f03bd36316bcd875a90058b3b831dff545;hb=4d0e56bee81c535293367628dd73cbf993d0690a;hpb=b27b7cc54f450bb5fe8c9ea2faf5e01d0082889a diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 6540a0f..d7f06fe 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -309,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), test_outputs = model((test_input - mu) / std).detach() if args.with_attention: - x = model[0:4]((test_input - mu) / std) - test_A = model[4].attention(x) + k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer)) + x = model[0:k]((test_input - mu) / std) + test_A = model[k].attention(x) test_A = test_A.detach().to('cpu') test_input = test_input.detach().to('cpu')