From d2c145b4306d5c36094618ff7e7323c5d083e1df Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 6 Nov 2020 16:20:11 +0100 Subject: [PATCH] Update. --- attentiontoy1d.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 92d90cf..e82894e 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -76,7 +76,7 @@ seq_width_min, seq_width_max = 5.0, 11.0 seq_length = 100 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3): - st = torch.arange(seq_length).float() + st = torch.arange(seq_length, device = device).float() st = st[None, :, None] tr = tr[:, None, :, :] bx = bx[:, None, :, :] @@ -86,7 +86,6 @@ def positions_to_sequences(tr = None, bx = None, noise_level = 0.3): x = torch.cat((xtr, xbx), 2) - # u = x.sign() u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1) collisions = (u.sum(2) > 1).max(1).values @@ -100,12 +99,12 @@ def generate_sequences(nb): # Position / height / width - tr = torch.empty(nb, 2, 3) + tr = torch.empty(nb, 2, 3, device = device) tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2) tr[:, :, 1].uniform_(seq_height_min, seq_height_max) tr[:, :, 2].uniform_(seq_width_min, seq_width_max) - bx = torch.empty(nb, 2, 3) + bx = torch.empty(nb, 2, 3, device = device) bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2) bx[:, :, 1].uniform_(seq_height_min, seq_height_max) bx[:, :, 2].uniform_(seq_width_min, seq_width_max) @@ -169,10 +168,10 @@ def save_sequence_images(filename, sequences, tr = None, bx = None): delta = -1. if tr is not None: - ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False) + ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False) if bx is not None: - ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False) + ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False) fig.savefig(filename, bbox_inches='tight') -- 2.20.1