Update.
[pytorch.git] / poly.py
diff --git a/poly.py b/poly.py
index 818742b..8aac207 100755 (executable)
--- a/poly.py
+++ b/poly.py
@@ -9,6 +9,7 @@
 
 import torch
 
+
 def pol_prod(a, b):
     m = a[:, None] * b[None, :]
     mm = m.new()
@@ -16,18 +17,26 @@ def pol_prod(a, b):
     k = torch.arange(a.size(0))[:, None] + torch.arange(b.size(0))[None, :]
     kk = k.new()
     kk.set_(k.storage(), 0, (k.size(0), k.size(0) + k.size(1) - 1), (k.size(1) - 1, 1))
-    q = (kk == torch.arange(a.size(0) + b.size(0) - 1)[None, :])
+    q = kk == torch.arange(a.size(0) + b.size(0) - 1)[None, :]
     return (mm * q).sum(0)
 
+
+def pol_eval(a, x):
+    d = torch.arange(a.size(0))
+    return (x[:, None].pow(d[None, :]) * a[None, :]).sum(1)
+
+
 def pol_prim(a):
     n = torch.arange(a.size(0) + 1).float()
     n[1:] = a / n[1:]
     return n
 
+
 ######################################################################
 
-if __name__ == '__main__':
-    a = torch.tensor([1., 2., 3.])
-    b = torch.tensor([2., 5.])
+if __name__ == "__main__":
+    a = torch.tensor([1.0, 2.0, 3.0])
+    b = torch.tensor([2.0, 5.0])
     print(pol_prod(a, b))
     print(pol_prim(b))
+    print(pol_eval(a, torch.tensor([0.0, 1.0, 2.0])))