OCD update.
[pytorch.git] / attentiontoy1d.py
index ad0c0b1..92d90cf 100755 (executable)
@@ -10,9 +10,11 @@ import torch, math, sys, argparse
 from torch import nn
 from torch.nn import functional as F
 
+import matplotlib.pyplot as plt
+
 ######################################################################
 
-parser = argparse.ArgumentParser(description='Toy RNN.')
+parser = argparse.ArgumentParser(description='Toy attention model.')
 
 parser.add_argument('--nb_epochs',
                     type = int, default = 250)
@@ -29,8 +31,15 @@ parser.add_argument('--positional_encoding',
                     help = 'Provide a positional encoding',
                     action='store_true', default=False)
 
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
 args = parser.parse_args()
 
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
 ######################################################################
 
 label=''
@@ -60,8 +69,6 @@ if torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-torch.manual_seed(1)
-
 ######################################################################
 
 seq_height_min, seq_height_max = 1.0, 25.0
@@ -146,9 +153,6 @@ def generate_sequences(nb):
 
 ######################################################################
 
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
 def save_sequence_images(filename, sequences, tr = None, bx = None):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
@@ -310,8 +314,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')