Update: cosmetic reformatting of qmlp.py (black-style line wrapping); no functional change in the hunks shown here.
diff --git a/qmlp.py b/qmlp.py
index b58598a..abebfc1 100755
--- a/qmlp.py
+++ b/qmlp.py
@@ -92,23 +92,37 @@ def generate_sets_and_params(
     test_input = dequantize(q_test_input, -1, 1)
 
     if save_as_examples:
-        a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1
-        xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1),
-                        a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2)
-        xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1)
+        a = (
+            2
+            * torch.arange(nb_quantization_levels).float()
+            / (nb_quantization_levels - 1)
+            - 1
+        )
+        xf = torch.cat(
+            [
+                a[:, None, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+                a[None, :, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+            ],
+            2,
+        )
+        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
         print(f"{xf.size()=} {x.size()=}")
         yf = (
             (
-                (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long()
-                * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long()
-                * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long()
-                * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long()
+                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
             )
             .max(dim=1)
             .values
         )
 
-        full_input, full_targets = xf,yf
+        full_input, full_targets = xf, yf
 
         q_full_input = quantize(full_input, -1, 1)
         full_input = dequantize(q_full_input, -1, 1)
@@ -208,8 +222,12 @@ def generate_sets_and_params(
 
 
 def evaluate_q_params(
-        q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024,
-        save_as_examples=False,
+    q_params,
+    q_set,
+    batch_size=25,
+    device=torch.device("cpu"),
+    nb_mlps_per_batch=1024,
+    save_as_examples=False,
 ):
     errors = []
     nb_mlps = q_params.size(0)
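
The first hunk only re-wraps the code that, when save_as_examples is set, builds the full grid of quantized 2D inputs and their labels. As a reading aid, here is a hedged, self-contained sketch of what that block computes, with made-up sizes and the assumption that rec_support has shape (nb_mlps, nb_rects, 4) holding [x_min, x_max, y_min, y_max] per support rectangle; the sketch uses torch.cartesian_prod instead of the expand/cat construction, which produces the grid in the same row-major order.

    import torch

    # Hypothetical sizes, for illustration only
    nb_quantization_levels, nb_mlps, nb_rects = 5, 3, 2
    # Assumed layout: one [x_min, x_max, y_min, y_max] row per support rectangle
    rec_support = torch.rand(nb_mlps, nb_rects, 4) * 2 - 1

    # Quantization-level coordinates, evenly spaced in [-1, 1]
    a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1

    # Every (x, y) cell of the full grid, replicated per MLP:
    # shape (nb_mlps, nb_quantization_levels**2, 2)
    xf = torch.cartesian_prod(a, a)[None].expand(nb_mlps, -1, -1)

    # A grid point gets label 1 if it lies inside at least one support rectangle;
    # the max over the rectangle dimension implements the "at least one" part.
    inside = (
        (xf[:, None, :, 0] >= rec_support[:, :, None, 0])
        & (xf[:, None, :, 0] <= rec_support[:, :, None, 1])
        & (xf[:, None, :, 1] >= rec_support[:, :, None, 2])
        & (xf[:, None, :, 1] <= rec_support[:, :, None, 3])
    )
    yf = inside.long().max(dim=1).values  # shape (nb_mlps, nb_quantization_levels**2)

In the actual code, xf is expanded to only min(q_train_input.size(0), 10) MLPs and rec_support is sliced with [: xf.size(0)] to match, so at most ten example grids are produced.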