pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)