Initial commit.
[pytorch.git] / attentiontoy1d.py
index e82894e..b463340 100755 (executable)
@@ -7,7 +7,7 @@
 
 import torch, math, sys, argparse
 
-from torch import nn
+from torch import nn, einsum
 from torch.nn import functional as F
 
 import matplotlib.pyplot as plt
@@ -181,7 +181,7 @@ def save_sequence_images(filename, sequences, tr = None, bx = None):
 
 class AttentionLayer(nn.Module):
     def __init__(self, in_channels, out_channels, key_channels):
-        super(AttentionLayer, self).__init__()
+        super().__init__()
         self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
         self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
@@ -190,9 +190,9 @@ class AttentionLayer(nn.Module):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
         V = self.conv_V(x)
-        A = Q.permute(0, 2, 1).matmul(K).softmax(2)
-        x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
-        return x
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        y = einsum('nts,ncs->nct', A, V)
+        return y
 
     def __repr__(self):
         return self._get_name() + \
@@ -205,7 +205,8 @@ class AttentionLayer(nn.Module):
     def attention(self, x):
         Q = self.conv_Q(x)
         K = self.conv_K(x)
-        return Q.permute(0, 2, 1).matmul(K).softmax(2)
+        A = einsum('nct,ncs->nts', Q, K).softmax(2)
+        return A
 
 ######################################################################
 
@@ -321,6 +322,8 @@ if args.with_attention:
 test_input = test_input.detach().to('cpu')
 test_outputs = test_outputs.detach().to('cpu')
 test_targets = test_targets.detach().to('cpu')
+test_bx = test_bx.detach().to('cpu')
+test_tr = test_tr.detach().to('cpu')
 
 for k in range(15):
     save_sequence_images(