projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
421aee4
..
b2f7d7d
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-76,6
+76,7
@@
class Task:
import problems
import problems
+
class SandBox(Task):
def __init__(
self,
class SandBox(Task):
def __init__(
self,
@@
-1134,8
+1135,8
@@
class RPL(Task):
)
if save_attention_image is not None:
)
if save_attention_image is not None:
- ns
=torch.randint(self.test_input.size(0),
(1,)).item()
- input = self.test_input[ns
:ns+
1].clone()
+ ns
= torch.randint(self.test_input.size(0),
(1,)).item()
+ input = self.test_input[ns
: ns +
1].clone()
last = (input != self.t_nul).max(0).values.nonzero().max() + 3
input = input[:, :last].to(self.device)
last = (input != self.t_nul).max(0).values.nonzero().max() + 3
input = input[:, :last].to(self.device)