Update.
[pytorch.git] / rmax.py
1 #!/usr/bin/env python
2
3 import torch
4
5 ##################################################
6
7
8 def rmax(x):
9     a = x.max(-1, keepdim=True)
10     i = torch.arange(x.size(-1) - 1)[None, :]
11     y = torch.cat(
12         (
13             (i < a.indices) * (x - a.values)[:, :-1]
14             + (i >= a.indices) * (a.values - x)[:, 1:],
15             a.values,
16         ),
17         -1,
18     )
19     return y
20
21
22 def rmax_back(y):
23     u = torch.nn.functional.pad(y, (1, -1))
24     x = (
25         (y < 0) * (y[:, -1:] + y)
26         + (y >= 0) * (u < 0) * (y[:, -1:])
27         + (y >= 0) * (u >= 0) * (y[:, -1:] - u)
28     )
29     return x
30
31
32 ##################################################
33
34 x = torch.randn(3, 14)
35 y = rmax(x)
36 print(f"{x.size()=} {x.max(-1).values=}")
37 print(f"{y.size()=} {y[:,-1]=}")
38
39 z = rmax_back(y)
40 print(f"{(z-x).abs().max()=}")