From: Francois Fleuret Date: Fri, 22 May 2020 11:22:54 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=4d0e56bee81c535293367628dd73cbf993d0690a Update. --- diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 6540a0f..d7f06fe 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -309,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), test_outputs = model((test_input - mu) / std).detach() if args.with_attention: - x = model[0:4]((test_input - mu) / std) - test_A = model[4].attention(x) + k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer)) + x = model[0:k]((test_input - mu) / std) + test_A = model[k].attention(x) test_A = test_A.detach().to('cpu') test_input = test_input.detach().to('cpu')