From 6183291906184569c2206c34588d118cc77f74bb Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 7 Jan 2024 11:54:31 +0100
Subject: [PATCH] Update.

---
 main.py  | 13 ++++++++++++-
 mygpt.py | 19 +++++++++++--------
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/main.py b/main.py
index 18c0730..cae20f8 100755
--- a/main.py
+++ b/main.py
@@ -24,6 +24,17 @@ else:
 
 ######################################################################
 
+
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError
+
+
 parser = argparse.ArgumentParser(
     description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -68,7 +79,7 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5)
 
 # legacy
 
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
 
 parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
 
diff --git a/mygpt.py b/mygpt.py
index f97af49..f10f1fe 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -472,15 +472,18 @@ def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device):
     fb_body = fb_body.cumsum(dim=2)
     fb_start = fb_start * (fb_body == 1)
 
-    # pick past starting source times
-    src_time = (
-        fb_start
+    # t_s = t0-(t0//L * R)*L
+
+    t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :]
+    src_time = fb_start * (
+        t
+        - CL
         * (
-            torch.rand(fb_start.size(), device=fb_start.device)
-            * (torch.arange(fb_start.size(2), device=fb_start.device) - CL)[
-                None, None, :
-            ]
-        ).long()
+            1
+            + (
+                torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1)
+            ).long()
+        )
     )
     src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL]
     src_time = src_time.cumsum(dim=2)
-- 
2.20.1
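
Note (illustration only, not part of the patch): the new src_time expression in mygpt.py is dense, so the following is a minimal standalone sketch of the sampling it performs. The sizes N, H, T, CL and the dummy fb_start mask are assumptions made for the example; in the repository they are built inside flash_back_time_src. At each start time t, the expression draws a past source time t - CL * k with k uniform over {1, ..., t // CL - 1}, i.e. an earlier position at the same offset within a chunk of length CL, at least one full chunk back.

import torch

# Hypothetical batch size, heads, time steps, chunk length (assumptions for the sketch).
N, H, T, CL = 2, 3, 64, 8

# Dummy mask of flash-back start positions, one per chunk, skipping the first
# two chunks so that t // CL - 1 >= 1 wherever a start occurs.
fb_start = torch.zeros(N, H, T, dtype=torch.long)
fb_start[:, :, 2 * CL :: CL] = 1

t = torch.arange(T)[None, None, :]

# Same expression as the patch: at every start time t, sample a past source time
# t - CL * k with k uniform in {1, ..., t // CL - 1}.
src_time = fb_start * (
    t - CL * (1 + (torch.rand(N, H, T) * (t // CL - 1)).long())
)

starts = fb_start[0, 0].bool()
print(torch.arange(T)[starts])  # start times
print(src_time[0, 0, starts])   # sampled sources, each between CL and t - CL for this mask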