X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=b2f7d7dc5f750610333d03b7da6c183d83ff7a7a;hb=16cb07f99cf770fb4e97824f874a68cbddd4c1cf;hp=421aee49f2f0005f7650ea9836fde801fc5598e8;hpb=db7cefe4fefb381e56f1292d5bbe4a18c76afb47;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 421aee4..b2f7d7d 100755 --- a/tasks.py +++ b/tasks.py @@ -76,6 +76,7 @@ class Task: import problems + class SandBox(Task): def __init__( self, @@ -1134,8 +1135,8 @@ class RPL(Task): ) if save_attention_image is not None: - ns=torch.randint(self.test_input.size(0),(1,)).item() - input = self.test_input[ns:ns+1].clone() + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 input = input[:, :last].to(self.device)