X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=hallu.py;h=de251884c8c8bafac1e307030ccce1f61ee1935e;hb=d09d91f2b5b594f91a757134c5ce014ae8d68a9a;hp=b738b521ebe1008958f4e44976e938d4af3c0db4;hpb=aacb2bf640ba8342bb49f3a6c285d00fac523540;p=pytorch.git diff --git a/hallu.py b/hallu.py index b738b52..de25188 100755 --- a/hallu.py +++ b/hallu.py @@ -14,7 +14,7 @@ from torch.nn import functional as F class MultiScaleEdgeEnergy(torch.nn.Module): def __init__(self): - super(MultiScaleEdgeEnergy, self).__init__() + super().__init__() k = torch.exp(- torch.tensor([[-2., -1., 0., 1., 2.]])**2 / 2) k = (k.t() @ k).view(1, 1, 5, 5) self.gaussian_5x5 = torch.nn.Parameter(k / k.sum()).requires_grad_(False)