X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=hallu.py;h=9d2706db39b8a954c74557687967eda18cf709ac;hp=de251884c8c8bafac1e307030ccce1f61ee1935e;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/hallu.py b/hallu.py index de25188..9d2706d 100755 --- a/hallu.py +++ b/hallu.py @@ -12,10 +12,11 @@ import PIL, torch, torchvision from torch.nn import functional as F + class MultiScaleEdgeEnergy(torch.nn.Module): def __init__(self): super().__init__() - k = torch.exp(- torch.tensor([[-2., -1., 0., 1., 2.]])**2 / 2) + k = torch.exp(-torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]]) ** 2 / 2) k = (k.t() @ k).view(1, 1, 5, 5) self.gaussian_5x5 = torch.nn.Parameter(k / k.sum()).requires_grad_(False) @@ -23,19 +24,20 @@ class MultiScaleEdgeEnergy(torch.nn.Module): u = x.view(-1, 1, x.size(2), x.size(3)) result = 0.0 while min(u.size(2), u.size(3)) > 5: - blurry = F.conv2d(u, self.gaussian_5x5, padding = 2) + blurry = F.conv2d(u, self.gaussian_5x5, padding=2) result += (u - blurry).view(u.size(0), -1).pow(2).sum(1) - u = F.avg_pool2d(u, kernel_size = 2, padding = 1) + u = F.avg_pool2d(u, kernel_size=2, padding=1) return result.view(x.size(0), -1).sum(1) -img = torchvision.transforms.ToTensor()(PIL.Image.open('blacklab.jpg')) + +img = torchvision.transforms.ToTensor()(PIL.Image.open("blacklab.jpg")) img = img.view((1,) + img.size()) ref_input = 0.5 + 0.5 * (img - img.mean()) / img.std() mse_loss = torch.nn.MSELoss() edge_energy = MultiScaleEdgeEnergy() -layers = torchvision.models.vgg16(pretrained = True).features +layers = torchvision.models.vgg16(pretrained=True).features layers.eval() if torch.cuda.is_available(): @@ -43,13 +45,13 @@ if torch.cuda.is_available(): ref_input = ref_input.cuda() layers.cuda() -for l in [ 5, 7, 12, 17, 21, 28 ]: +for l in [5, 7, 12, 17, 21, 28]: model = torch.nn.Sequential(layers[:l]) ref_output = model(ref_input).detach() for n in range(5): input = torch.empty_like(ref_input).uniform_(-0.01, 0.01).requires_grad_() - optimizer = torch.optim.Adam( [ input ], lr = 1e-2) + optimizer = torch.optim.Adam([input], lr=1e-2) for k in range(1000): output = model(input) loss = mse_loss(output, ref_output) + 1e-3 * edge_energy(input) @@ -58,7 +60,7 @@ for l in [ 5, 7, 12, 17, 21, 28 ]: optimizer.step() img = 0.5 + 0.2 * (input - input.mean()) / input.std() - result_name = 'hallu-l%02d-n%02d.png' % (l, n) + result_name = "hallu-l%02d-n%02d.png" % (l, n) torchvision.utils.save_image(img, result_name) - print('Wrote ' + result_name) + print("Wrote " + result_name)