+ # Build a uniform grid of nb_quantization_levels^2 points covering [-1, 1]^2.
+ a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1
+ xf = torch.cat([a[:, None, None].expand(nb_quantization_levels, nb_quantization_levels, 1),
+                 a[None, :, None].expand(nb_quantization_levels, nb_quantization_levels, 1)], 2)
+ xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
+ print(f"{xf.size()=} {q_train_input.size()=}")
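+ # yf[k, j] is 1 if grid point j falls inside at least one rectangle of sample k,
+ # taking each rec_support row as (x_min, x_max, y_min, y_max).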
+ yf = (
+     (
+         (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long()
+         * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long()
+         * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long()
+         * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long()
+     )
+     .max(dim=1)
+     .values
+ )
+
+ full_input, full_targets = xf, yf
+
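+ # Round-trip through the quantizer, presumably to snap the grid coordinates to
+ # the same discrete levels as the training inputs.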
+ q_full_input = quantize(full_input, -1, 1)
+ full_input = dequantize(q_full_input, -1, 1)
+
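+ # Dump the first full grids to disk, one "<label> <x> <y>" line per grid point
+ # (e.g. for plotting with gnuplot).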
+ for k in range(q_full_input[:10].size(0)):
+     with open(f"example_full_{k:04d}.dat", "w") as f:
+         for u, c in zip(full_input[k], full_targets[k]):
+             f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
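+ # Same format for the first training samples.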
+ for k in range(q_train_input[:10].size(0)):
+     with open(f"example_train_{k:04d}.dat", "w") as f: