Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 17 Mar 2024 21:42:46 +0000 (22:42 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 17 Mar 2024 21:42:46 +0000 (22:42 +0100)
ideal_rnn.py [new file with mode: 0755]
tiny_vae.py

diff --git a/ideal_rnn.py b/ideal_rnn.py
new file mode 100755 (executable)
index 0000000..16d6059
--- /dev/null
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+
+import torch
+
+######################################################################
+
+
+def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000):
+    n_star = torch.randint(N, (1,)).item()
+    r = torch.zeros(D)
+    for k in range(nb_max_sequences):
+        X = torch.randn(N, D)
+        y = X[n_star] + torch.randn(D) * 0.5
+        for n in range(X.size(0)):
+            r = fun_f(N, k, n, X[n], r)
+        r, n_star_hat = fun_g(N, k, y, r)
+        if n_star_hat is not None:
+            return k + 1, n_star_hat == n_star
+    return -1, False
+
+
+def multi_test(fun_f, fun_g):
+    result = {False: [], True: []}
+    N = 100
+    D = 25
+    for u in range(100):
+        nb_realizations, correctness = single_test(D, N, fun_f, fun_g)
+        result[correctness].append(nb_realizations)
+
+    return torch.tensor(result[False]), torch.tensor(result[True])
+
+
+######################################################################
+
+d_best_id = 0
+d_best_mean = 1
+d_current_id = 2
+d_current_sum = 3
+d_current_sum_sq = 4
+d_current_nb = 5
+d_content = 6
+
+# N is the sequence length
+# k the index of the realization
+# n the index of the X in the current realization
+# x is X^k_n
+# r is R^k
+
+
+def fun_f(N, k, n, x, r):
+    if k == 0 and n == 0:
+        r[d_best_mean] = 1e9
+        r[d_current_id] = 0
+        r[d_current_sum] = 0.0
+        r[d_current_sum_sq] = 0.0
+        r[d_current_nb] = 0
+
+    if n == r[d_current_id]:
+        r[d_content:] = x[d_content:]
+
+    return r
+
+
+def fun_g(N, k, y, r):
+    current_mean = r[d_current_sum] / r[d_current_nb]
+    current_std = (
+        (r[d_current_sum_sq] / r[d_current_nb] - current_mean**2).sqrt().item()
+    )
+
+    if (
+        r[d_current_nb] > 1
+        and current_std / r[d_current_nb].sqrt() < (current_mean - r[d_best_mean]).abs()
+    ):
+        if current_mean <= r[d_best_mean]:
+            r[d_best_id] = r[d_current_id]
+            r[d_best_mean] = current_mean
+
+        r[d_current_nb] = 0
+        r[d_current_sum] = 0
+        r[d_current_sum_sq] = 0
+        r[d_current_id] += 1
+
+        if r[d_current_id] == N:
+            return r, r[d_best_id].long().item()
+
+    norm = (y[d_content:] - r[d_content:]).norm()
+    r[d_current_nb] += 1
+    r[d_current_sum] += norm
+    r[d_current_sum_sq] += norm**2
+
+    return r, None
+
+
+######################################################################
+
+r_failure, r_succes = multi_test(fun_f, fun_g)
+
+n_failures = r_failure.size(0)
+n_successes = r_succes.size(0)
+
+print(
+    f"ERRORS_RATE {n_failures/(n_failures+n_successes)} ({n_failures}/{n_failures+n_successes})"
+)
+print(f"K {r_succes.float().mean()} (+/- {r_succes.float().std()})")
+
+######################################################################
index fa09831..4d11c7f 100755 (executable)
@@ -175,12 +175,12 @@ def save_images(model, prefix=""):
     def save_image(x, filename):
         x = x * train_std + train_mu
         x = x.clamp(min=0, max=255) / 255
-        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+        torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0)
         log_string(f"wrote {filename}")
 
     # Save a bunch of train images
 
-    x = train_input[:256]
+    x = train_input[:36]
     save_image(x, f"{prefix}train_input.png")
 
     # Save the same images after encoding / decoding
@@ -194,7 +194,7 @@ def save_images(model, prefix=""):
 
     # Save a bunch of test images
 
-    x = test_input[:256]
+    x = test_input[:36]
     save_image(x, f"{prefix}input.png")
 
     # Save the same images after encoding / decoding