X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=ideal_rnn.py;fp=ideal_rnn.py;h=16d605990958cf2ee8cafb6ca8fbeded9ea038dd;hp=0000000000000000000000000000000000000000;hb=bbe5b7ddb723696fb5388be950af252cb95eb5fb;hpb=0fdaaceb231d31d53d0c623848b8ac56964bedb5 diff --git a/ideal_rnn.py b/ideal_rnn.py new file mode 100755 index 0000000..16d6059 --- /dev/null +++ b/ideal_rnn.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python + +import torch + +###################################################################### + + +def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000): + n_star = torch.randint(N, (1,)).item() + r = torch.zeros(D) + for k in range(nb_max_sequences): + X = torch.randn(N, D) + y = X[n_star] + torch.randn(D) * 0.5 + for n in range(X.size(0)): + r = fun_f(N, k, n, X[n], r) + r, n_star_hat = fun_g(N, k, y, r) + if n_star_hat is not None: + return k + 1, n_star_hat == n_star + return -1, False + + +def multi_test(fun_f, fun_g): + result = {False: [], True: []} + N = 100 + D = 25 + for u in range(100): + nb_realizations, correctness = single_test(D, N, fun_f, fun_g) + result[correctness].append(nb_realizations) + + return torch.tensor(result[False]), torch.tensor(result[True]) + + +###################################################################### + +d_best_id = 0 +d_best_mean = 1 +d_current_id = 2 +d_current_sum = 3 +d_current_sum_sq = 4 +d_current_nb = 5 +d_content = 6 + +# N is the sequence length +# k the index of the realization +# n the index of the X in the current realization +# x is X^k_n +# r is R^k + + +def fun_f(N, k, n, x, r): + if k == 0 and n == 0: + r[d_best_mean] = 1e9 + r[d_current_id] = 0 + r[d_current_sum] = 0.0 + r[d_current_sum_sq] = 0.0 + r[d_current_nb] = 0 + + if n == r[d_current_id]: + r[d_content:] = x[d_content:] + + return r + + +def fun_g(N, k, y, r): + current_mean = r[d_current_sum] / r[d_current_nb] + current_std = ( + (r[d_current_sum_sq] / r[d_current_nb] - current_mean**2).sqrt().item() + ) + + if ( + r[d_current_nb] > 1 + and current_std / r[d_current_nb].sqrt() < (current_mean - r[d_best_mean]).abs() + ): + if current_mean <= r[d_best_mean]: + r[d_best_id] = r[d_current_id] + r[d_best_mean] = current_mean + + r[d_current_nb] = 0 + r[d_current_sum] = 0 + r[d_current_sum_sq] = 0 + r[d_current_id] += 1 + + if r[d_current_id] == N: + return r, r[d_best_id].long().item() + + norm = (y[d_content:] - r[d_content:]).norm() + r[d_current_nb] += 1 + r[d_current_sum] += norm + r[d_current_sum_sq] += norm**2 + + return r, None + + +###################################################################### + +r_failure, r_succes = multi_test(fun_f, fun_g) + +n_failures = r_failure.size(0) +n_successes = r_succes.size(0) + +print( + f"ERRORS_RATE {n_failures/(n_failures+n_successes)} ({n_failures}/{n_failures+n_successes})" +) +print(f"K {r_succes.float().mean()} (+/- {r_succes.float().std()})") + +######################################################################