X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=poly.py;h=8aac2078443f7e859f388bc6cf60cea1af6f8679;hp=1b157a82275b591efcdd74d0c10a08d584282ef8;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/poly.py b/poly.py index 1b157a8..8aac207 100755 --- a/poly.py +++ b/poly.py @@ -9,6 +9,7 @@ import torch + def pol_prod(a, b): m = a[:, None] * b[None, :] mm = m.new() @@ -16,23 +17,26 @@ def pol_prod(a, b): k = torch.arange(a.size(0))[:, None] + torch.arange(b.size(0))[None, :] kk = k.new() kk.set_(k.storage(), 0, (k.size(0), k.size(0) + k.size(1) - 1), (k.size(1) - 1, 1)) - q = (kk == torch.arange(a.size(0) + b.size(0) - 1)[None, :]) + q = kk == torch.arange(a.size(0) + b.size(0) - 1)[None, :] return (mm * q).sum(0) + def pol_eval(a, x): d = torch.arange(a.size(0)) return (x[:, None].pow(d[None, :]) * a[None, :]).sum(1) + def pol_prim(a): n = torch.arange(a.size(0) + 1).float() n[1:] = a / n[1:] return n + ###################################################################### -if __name__ == '__main__': - a = torch.tensor([1., 2., 3.]) - b = torch.tensor([2., 5.]) +if __name__ == "__main__": + a = torch.tensor([1.0, 2.0, 3.0]) + b = torch.tensor([2.0, 5.0]) print(pol_prod(a, b)) print(pol_prim(b)) print(pol_eval(a, torch.tensor([0.0, 1.0, 2.0])))