OCD update.
[pytorch.git] / attentiontoy1d.py
index 6540a0f..92d90cf 100755 (executable)
@@ -31,8 +31,15 @@ parser.add_argument('--positional_encoding',
                     help = 'Provide a positional encoding',
                     action='store_true', default=False)
 
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
 args = parser.parse_args()
 
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
 ######################################################################
 
 label=''
@@ -62,8 +69,6 @@ if torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-torch.manual_seed(1)
-
 ######################################################################
 
 seq_height_min, seq_height_max = 1.0, 25.0
@@ -309,8 +314,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')