######################################################################
 
-    def sigma_for_grids(self, input):
+    def sigma_for_grids(self, input, block_order=(0, 1, 2, 3)):
         l = input.size(1) // 4
         sigma = input.new(input.size())
         r = sigma.view(sigma.size(0), 4, l)
-        r[:, 0] = 0 * l
-        r[:, 1] = 1 * l
-        r[:, 2] = 2 * l
-        r[:, 3] = 3 * l
+        r[:, 0, :] = block_order[0] * l
+        r[:, 1, :] = block_order[1] * l
+        r[:, 2, :] = block_order[2] * l
+        r[:, 3, :] = block_order[3] * l
         r[:, :, 1:] += (
             torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices
         ) + 1