from torch import nn
from torch.nn import functional as F
+import matplotlib.pyplot as plt
+
######################################################################
parser = argparse.ArgumentParser(description='Toy attention model.')
######################################################################
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
- x = model[0:4]((test_input - mu) / std)
- test_A = model[4].attention(x)
+ k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+ x = model[0:k]((test_input - mu) / std)
+ test_A = model[k].attention(x)
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')