X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=b463340015b2814b1f2b71c8a0611b7270d499a7;hb=d74d7be5abef26c78d014bd179f2c52f81aca65b;hp=e82894e650a05c36b8e6779dcf840fa3182319d7;hpb=d2c145b4306d5c36094618ff7e7323c5d083e1df;p=pytorch.git

diff --git a/attentiontoy1d.py b/attentiontoy1d.py
index e82894e..b463340 100755
--- a/attentiontoy1d.py
+++ b/attentiontoy1d.py
@@ -7,7 +7,7 @@
 
 import torch, math, sys, argparse
 
-from torch import nn
+from torch import nn, einsum
 from torch.nn import functional as F
 
 import matplotlib.pyplot as plt
@@ -181,7 +181,7 @@ def save_sequence_images(filename, sequences, tr = None, bx = None):
 
 class AttentionLayer(nn.Module):
     def __init__(self, in_channels, out_channels, key_channels):
-        super(AttentionLayer, self).__init__()
+        super().__init__()
         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
@@ -190,9 +190,9 @@ class AttentionLayer(nn.Module):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
         V = self.conv_V(x)
-        A = Q.permute(0, 2, 1).matmul(K).softmax(2)
-        x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
-        return x
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        y = einsum('nts,ncs->nct', A, V)
+        return y
 
     def __repr__(self):
         return self._get_name() + \
@@ -205,7 +205,8 @@ class AttentionLayer(nn.Module):
     def attention(self, x):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
-        return Q.permute(0, 2, 1).matmul(K).softmax(2)
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        return A
 
 ######################################################################
 
@@ -321,6 +322,8 @@ if args.with_attention:
 test_input = test_input.detach().to('cpu')
 test_outputs = test_outputs.detach().to('cpu')
 test_targets = test_targets.detach().to('cpu')
+test_bx = test_bx.detach().to('cpu')
+test_tr = test_tr.detach().to('cpu')
 
 for k in range(15):
     save_sequence_images(
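
The change above rewrites the attention layer's permute/matmul contractions as torch.einsum calls. The standalone sketch below is not part of the patch; it uses made-up shapes (batch N, key channels K, value channels C, sequence length T) and random tensors named Q, Kx, V standing in for the outputs of conv_Q, conv_K and conv_V, and checks that the old and new formulations agree.

import torch
from torch import einsum

# Hypothetical shapes: batch, value channels, key channels, sequence length.
N, C, K, T = 2, 8, 4, 16

Q  = torch.randn(N, K, T)   # stands in for self.conv_Q(x)
Kx = torch.randn(N, K, T)   # stands in for self.conv_K(x)
V  = torch.randn(N, C, T)   # stands in for self.conv_V(x)

# Old formulation: explicit permutes and matmuls.
A_old = Q.permute(0, 2, 1).matmul(Kx).softmax(2)            # (N, T, S) attention map
y_old = A_old.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)   # (N, C, T) output

# New formulation from the diff: the same contractions written with einsum.
A_new = einsum('nct,ncs->nts', Q, Kx).softmax(2)
y_new = einsum('nts,ncs->nct', A_new, V)

print(torch.allclose(A_old, A_new, atol=1e-6))  # expected: True
print(torch.allclose(y_old, y_new, atol=1e-6))  # expected: True

The einsum form avoids the intermediate permuted copies and makes the axis roles explicit in the equation string (n: batch, c: channels, t: query position, s: key position).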