X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=pscan.py;fp=pscan.py;h=c5478233b5c1156b36a7ac082e2aa8c0db613805;hp=3526c31ed601270bb507f2c7771284ad455b5c62;hb=24557d01df2fea576a593bc039318e21a06f7ae4;hpb=674dd7c7adde6b4a9aaa5afd57dbe1d063a47fcc diff --git a/pscan.py b/pscan.py index 3526c31..c547823 100755 --- a/pscan.py +++ b/pscan.py @@ -114,3 +114,8 @@ if __name__ == "__main__": print((gA - gA_ref).norm()) print((gX - gX_ref).norm()) print((gY_init - gY_init_ref).norm()) + + Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) + + print((Y - torch.cat([Y1, Y2], dim=1)).norm())