From: François Fleuret Date: Sun, 23 Jul 2023 18:29:54 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=16cb07f99cf770fb4e97824f874a68cbddd4c1cf;p=picoclvr.git Update. --- diff --git a/problems.py b/problems.py index 78bb64e..5161587 100755 --- a/problems.py +++ b/problems.py @@ -156,4 +156,3 @@ class ProblemAddition(Problem): # for strain, stest in zip(train_seq, test_seq): # s = torch.cat((strain, stest), 0) - diff --git a/tasks.py b/tasks.py index 421aee4..b2f7d7d 100755 --- a/tasks.py +++ b/tasks.py @@ -76,6 +76,7 @@ class Task: import problems + class SandBox(Task): def __init__( self, @@ -1134,8 +1135,8 @@ class RPL(Task): ) if save_attention_image is not None: - ns=torch.randint(self.test_input.size(0),(1,)).item() - input = self.test_input[ns:ns+1].clone() + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 input = input[:, :last].to(self.device)