X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=b463340015b2814b1f2b71c8a0611b7270d499a7;hb=d74d7be5abef26c78d014bd179f2c52f81aca65b;hp=6540a0f03bd36316bcd875a90058b3b831dff545;hpb=b27b7cc54f450bb5fe8c9ea2faf5e01d0082889a;p=pytorch.git

diff --git a/attentiontoy1d.py b/attentiontoy1d.py
index 6540a0f..b463340 100755
--- a/attentiontoy1d.py
+++ b/attentiontoy1d.py
@@ -7,7 +7,7 @@
 
 import torch, math, sys, argparse
 
-from torch import nn
+from torch import nn, einsum
 from torch.nn import functional as F
 
 import matplotlib.pyplot as plt
@@ -31,8 +31,15 @@ parser.add_argument('--positional_encoding',
                     help = 'Provide a positional encoding',
                     action='store_true', default=False)
 
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
 args = parser.parse_args()
 
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
 ######################################################################
 
 label=''
@@ -62,8 +69,6 @@ if torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-torch.manual_seed(1)
-
 ######################################################################
 
 seq_height_min, seq_height_max = 1.0, 25.0
@@ -71,7 +76,7 @@ seq_width_min, seq_width_max = 5.0, 11.0
 seq_length = 100
 
 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
-    st = torch.arange(seq_length).float()
+    st = torch.arange(seq_length, device = device).float()
     st = st[None, :, None]
     tr = tr[:, None, :, :]
    bx = bx[:, None, :, :]
@@ -81,7 +86,6 @@ def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
 
     x = torch.cat((xtr, xbx), 2)
 
-    # u = x.sign()
     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
 
     collisions = (u.sum(2) > 1).max(1).values
@@ -95,12 +99,12 @@ def generate_sequences(nb):
 
     # Position / height / width
 
-    tr = torch.empty(nb, 2, 3)
+    tr = torch.empty(nb, 2, 3, device = device)
     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
 
-    bx = torch.empty(nb, 2, 3)
+    bx = torch.empty(nb, 2, 3, device = device)
     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
@@ -164,10 +168,10 @@ def save_sequence_images(filename, sequences, tr = None, bx = None):
     delta = -1.
 
     if tr is not None:
-        ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
+        ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)
 
     if bx is not None:
-        ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
+        ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)
 
     fig.savefig(filename, bbox_inches='tight')
 
@@ -177,7 +181,7 @@ def save_sequence_images(filename, sequences, tr = None, bx = None):
 
 class AttentionLayer(nn.Module):
     def __init__(self, in_channels, out_channels, key_channels):
-        super(AttentionLayer, self).__init__()
+        super().__init__()
         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
@@ -186,9 +190,9 @@ class AttentionLayer(nn.Module):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
         V = self.conv_V(x)
-        A = Q.permute(0, 2, 1).matmul(K).softmax(2)
-        x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
-        return x
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        y = einsum('nts,ncs->nct', A, V)
+        return y
 
     def __repr__(self):
         return self._get_name() + \
@@ -201,7 +205,8 @@ class AttentionLayer(nn.Module):
     def attention(self, x):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
-        return Q.permute(0, 2, 1).matmul(K).softmax(2)
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        return A
 
 ######################################################################
 
@@ -309,13 +314,16 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')
 test_outputs = test_outputs.detach().to('cpu')
 test_targets = test_targets.detach().to('cpu')
+test_bx = test_bx.detach().to('cpu')
+test_tr = test_tr.detach().to('cpu')
 
 for k in range(15):
     save_sequence_images(
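Side note (not part of the patch): the einsum rewrite in AttentionLayer.forward computes the same attention as the earlier permute/matmul formulation. A minimal equivalence check, with toy tensor shapes chosen here purely for illustration (batch 2, key channels 8, value channels 16, sequence length 100), sketched as:

import torch
from torch import einsum

# Toy Q/K/V in the layer's (batch, channels, time) layout; shapes are illustrative only.
Q = torch.randn(2, 8, 100)
K = torch.randn(2, 8, 100)
V = torch.randn(2, 16, 100)

# Old formulation: permute + matmul.
A_old = Q.permute(0, 2, 1).matmul(K).softmax(2)
y_old = A_old.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)

# New formulation from the patch: einsum over the channel / time indices.
A_new = einsum('nct,ncs->nts', Q, K).softmax(2)
y_new = einsum('nts,ncs->nct', A_new, V)

print(torch.allclose(y_old, y_new, atol = 1e-6))  # expected: True (up to float rounding)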