projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
e282628
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Fri, 6 Nov 2020 15:20:11 +0000
(16:20 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Fri, 6 Nov 2020 15:20:11 +0000
(16:20 +0100)
attentiontoy1d.py
patch
|
blob
|
history
diff --git
a/attentiontoy1d.py
b/attentiontoy1d.py
index
92d90cf
..
e82894e
100755
(executable)
--- a/
attentiontoy1d.py
+++ b/
attentiontoy1d.py
@@
-76,7
+76,7
@@
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
seq_length = 100
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
- st = torch.arange(seq_length).float()
+ st = torch.arange(seq_length
, device = device
).float()
st = st[None, :, None]
tr = tr[:, None, :, :]
bx = bx[:, None, :, :]
st = st[None, :, None]
tr = tr[:, None, :, :]
bx = bx[:, None, :, :]
@@
-86,7
+86,6
@@
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
x = torch.cat((xtr, xbx), 2)
x = torch.cat((xtr, xbx), 2)
- # u = x.sign()
u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
collisions = (u.sum(2) > 1).max(1).values
u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
collisions = (u.sum(2) > 1).max(1).values
@@
-100,12
+99,12
@@
def generate_sequences(nb):
# Position / height / width
# Position / height / width
- tr = torch.empty(nb, 2, 3)
+ tr = torch.empty(nb, 2, 3
, device = device
)
tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
- bx = torch.empty(nb, 2, 3)
+ bx = torch.empty(nb, 2, 3
, device = device
)
bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
@@
-169,10
+168,10
@@
def save_sequence_images(filename, sequences, tr = None, bx = None):
delta = -1.
if tr is not None:
delta = -1.
if tr is not None:
- ax.scatter(t
est_tr[k, :, 0], torch.full((test_tr.size(1
),), delta), color = 'black', marker = '^', clip_on=False)
+ ax.scatter(t
r[:, 0].cpu(), torch.full((tr.size(0
),), delta), color = 'black', marker = '^', clip_on=False)
if bx is not None:
if bx is not None:
- ax.scatter(
test_bx[k, :, 0], torch.full((test_bx.size(1
),), delta), color = 'black', marker = 's', clip_on=False)
+ ax.scatter(
bx[:, 0].cpu(), torch.full((bx.size(0
),), delta), color = 'black', marker = 's', clip_on=False)
fig.savefig(filename, bbox_inches='tight')
fig.savefig(filename, bbox_inches='tight')