Update.
[mygptrnn.git] / blanket.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10
11 class Blanket(torch.autograd.Function):
12     @staticmethod
13     def normalize(x):
14         y = x.flatten(1)
15         y /= y.pow(2).sum(dim=1, keepdim=True).sqrt() + 1e-6
16         y *= math.sqrt(y.numel() / y.size(0))
17
18     @staticmethod
19     def forward(ctx, x):
20         x = x.clone()
21         # Normalize the forward
22         Blanket.normalize(x)
23         return x
24
25     @staticmethod
26     def backward(ctx, grad_output):
27         grad_output = grad_output.clone()
28         # Normalize the gradient
29         Blanket.normalize(grad_output)
30         return grad_output
31
32
33 blanket = Blanket.apply
34
35 ######################################################################
36
37 if __name__ == "__main__":
38     x = torch.rand(2, 3).requires_grad_()
39     y = blanket(x) * 10
40     print(y.pow(2).sum())
41     z = y.sin().sum()
42     g = torch.autograd.grad(z, x)[0]
43
44     print(g.pow(2).sum())