- for i in range(src.dim()):
- v = s_src[i]
- if v == "_":
- index, s_index = indexes[n_index], s_indexes[n_index]
- n_index += 1
+ for i in range(x.dim()):
+ v = s_x[i]
+ if v == "_":
+ idx.append(do(indexes[n_index], s_indexes[n_index]))
+ n_index += 1
+ else:
+ j = s_dst.index(v)
+ a = (
+ torch.arange(x.size(i))
+ .reshape((1,) * j + (-1,) + (1,) * (len(s_dst) - j - 1))
+ .expand(shape)
+ )
+ idx.append(a)
+
+ return x[idx]
+
+ return do(src, s_src)
+
+
+def lambda_eingather(op, src_shape, *indexes_shape):
+ s_src, s_dst = re.search("^([^ ]*) *-> *(.*)", op).groups()
+ s_indexes = re.findall("\(([^)]*)\)", s_src)
+ s_src = re.sub("\([^)]*\)", "_", s_src)