X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=b463340015b2814b1f2b71c8a0611b7270d499a7;hb=5ab8211805831629148d7b436b8770590f1987b0;hp=d389f0c9b9114ac52378a21c65f472c5f100e746;hpb=fdd573490e517d38fb0477ae1b5df12b74718d45;p=pytorch.git diff --git a/attentiontoy1d.py b/attentiontoy1d.py index d389f0c..b463340 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -181,7 +181,7 @@ def save_sequence_images(filename, sequences, tr = None, bx = None): class AttentionLayer(nn.Module): def __init__(self, in_channels, out_channels, key_channels): - super(AttentionLayer, self).__init__() + super().__init__() self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False) self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False) self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)