Update.
[pytorch.git] / poly.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 ######################################################################
9
10 import torch
11
12
13 def pol_prod(a, b):
14     m = a[:, None] * b[None, :]
15     mm = m.new()
16     mm.set_(m.storage(), 0, (m.size(0), m.size(0) + m.size(1) - 1), (m.size(1) - 1, 1))
17     k = torch.arange(a.size(0))[:, None] + torch.arange(b.size(0))[None, :]
18     kk = k.new()
19     kk.set_(k.storage(), 0, (k.size(0), k.size(0) + k.size(1) - 1), (k.size(1) - 1, 1))
20     q = kk == torch.arange(a.size(0) + b.size(0) - 1)[None, :]
21     return (mm * q).sum(0)
22
23
24 def pol_eval(a, x):
25     d = torch.arange(a.size(0))
26     return (x[:, None].pow(d[None, :]) * a[None, :]).sum(1)
27
28
29 def pol_prim(a):
30     n = torch.arange(a.size(0) + 1).float()
31     n[1:] = a / n[1:]
32     return n
33
34
35 ######################################################################
36
37 if __name__ == "__main__":
38     a = torch.tensor([1.0, 2.0, 3.0])
39     b = torch.tensor([2.0, 5.0])
40     print(pol_prod(a, b))
41     print(pol_prim(b))
42     print(pol_eval(a, torch.tensor([0.0, 1.0, 2.0])))