X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=attentiontoy1d.py;h=92d90cf79bb62ac618b1bd7590ee8f8fb498cba8;hp=d7f06fe0b587ba8f08dbfdda93ca58728a955f84;hb=04aff2fbb201d7987957203b9bb7b667f46c4fe9;hpb=4d0e56bee81c535293367628dd73cbf993d0690a diff --git a/attentiontoy1d.py b/attentiontoy1d.py index d7f06fe..92d90cf 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -31,8 +31,15 @@ parser.add_argument('--positional_encoding', help = 'Provide a positional encoding', action='store_true', default=False) +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + args = parser.parse_args() +if args.seed >= 0: + torch.manual_seed(args.seed) + ###################################################################### label='' @@ -62,8 +69,6 @@ if torch.cuda.is_available(): else: device = torch.device('cpu') -torch.manual_seed(1) - ###################################################################### seq_height_min, seq_height_max = 1.0, 25.0