X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mi_estimator.py;h=47381ef3007bf511341dbb0aa3dbc654a45e29a2;hb=d74d7be5abef26c78d014bd179f2c52f81aca65b;hp=68fd51f75e7a226b09da6e0b2667228537a848ca;hpb=aacb2bf640ba8342bb49f3a6c285d00fac523540;p=pytorch.git diff --git a/mi_estimator.py b/mi_estimator.py index 68fd51f..47381ef 100755 --- a/mi_estimator.py +++ b/mi_estimator.py @@ -226,7 +226,7 @@ def create_sequences_pairs(train = False): class NetForImagePair(nn.Module): def __init__(self): - super(NetForImagePair, self).__init__() + super().__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -257,7 +257,7 @@ class NetForImagePair(nn.Module): class NetForImageValuesPair(nn.Module): def __init__(self): - super(NetForImageValuesPair, self).__init__() + super().__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -306,7 +306,7 @@ class NetForSequencePair(nn.Module): ) def __init__(self): - super(NetForSequencePair, self).__init__() + super().__init__() self.nc = 32 self.nh = 256