Update.
[pytorch.git] / attentiontoy1d.py
index cff8350..d7f06fe 100755 (executable)
@@ -1,18 +1,20 @@
 #!/usr/bin/env python
 
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: /home/fleuret/conda/bin/python
-# @XREMOTE_PRE: killall -q -9 python || echo "Nothing killed"
-# @XREMOTE_GET: *.pdf *.log
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
 
 import torch, math, sys, argparse
 
 from torch import nn
 from torch.nn import functional as F
 
+import matplotlib.pyplot as plt
+
 ######################################################################
 
-parser = argparse.ArgumentParser(description='Toy RNN.')
+parser = argparse.ArgumentParser(description='Toy attention model.')
 
 parser.add_argument('--nb_epochs',
                     type = int, default = 250)
@@ -146,9 +148,6 @@ def generate_sequences(nb):
 
 ######################################################################
 
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
 def save_sequence_images(filename, sequences, tr = None, bx = None):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
@@ -310,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')