Update.
[pytorch.git] / eingather.py
index 588aacb..734edbe 100755 (executable)
@@ -15,11 +15,11 @@ def eingather(op, src, *indexes):
     s_indexes = re.findall("\(([^)]*)\)", s_src)
     s_src = re.sub("\([^)]*\)", "_", s_src)
 
-    all_sizes = tuple(d for s in ( src, ) + indexes for d in s.size())
+    all_sizes = tuple(d for s in (src,) + indexes for d in s.size())
     s_all = "".join([s_src] + s_indexes)
     shape = tuple(all_sizes[s_all.index(v)] for v in s_dst)
 
-    def do(x,s_x):
+    def do(x, s_x):
         idx = []
         n_index = 0
 
@@ -39,7 +39,8 @@ def eingather(op, src, *indexes):
 
         return x[idx]
 
-    return do(src,s_src)
+    return do(src, s_src)
+
 
 #######################
 
@@ -47,7 +48,7 @@ src = torch.rand(3, 5, 7, 11)
 index1 = torch.randint(src.size(2), (src.size(3), src.size(1), src.size(3)))
 index2 = torch.randint(src.size(3), (src.size(1),))
 
-# I want result[a, c, e] = src[c, a, index1[e, a, e], index2[a], e]
+# result[a, c, e] = src[c, a, index1[e, a, e], index2[a]]
 
 result = eingather("ca(eae)(a) -> ace", src, index1, index2)