projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
76e62a5
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sun, 22 Oct 2023 17:54:13 +0000
(19:54 +0200)
committer
François Fleuret
<francois@fleuret.org>
Sun, 22 Oct 2023 17:54:13 +0000
(19:54 +0200)
problems.py
patch
|
blob
|
history
diff --git
a/problems.py
b/problems.py
index
b8fcdb3
..
632c059
100755
(executable)
--- a/
problems.py
+++ b/
problems.py
@@
-298,13
+298,21
@@
class ProblemMixing(Problem):
# m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
# m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
- i = torch.arange(self.height).reshape(1,-1,1).expand(nb,self.height,self.width)
- j = torch.arange(self.width).reshape(1,1,-1).expand(nb,self.height,self.width)
+ i = (
+ torch.arange(self.height)
+ .reshape(1, -1, 1)
+ .expand(nb, self.height, self.width)
+ )
+ j = (
+ torch.arange(self.width)
+ .reshape(1, 1, -1)
+ .expand(nb, self.height, self.width)
+ )
- ri = torch.randint(self.height, (nb,)).reshape(nb,
1,
1)
- rj = torch.randint(self.width, (nb,)).reshape(nb,
1,
1)
+ ri = torch.randint(self.height, (nb,)).reshape(nb,
1,
1)
+ rj = torch.randint(self.width, (nb,)).reshape(nb,
1,
1)
- m = 1 - torch.logical_or(i
==ri,j==
rj).long().flatten(1)
+ m = 1 - torch.logical_or(i
== ri, j ==
rj).long().flatten(1)
y = (y * m + self.height * self.width * (1 - m)).reshape(
nb, self.height, self.width
y = (y * m + self.height * self.width * (1 - m)).reshape(
nb, self.height, self.width
@@
-313,16
+321,20
@@
class ProblemMixing(Problem):
return y
def start_error(self, x):
return y
def start_error(self, x):
- i = torch.arange(self.height, device=x.device).reshape(1,
-1,
1).expand_as(x)
- j = torch.arange(self.width, device=x.device).reshape(1,
1,
-1).expand_as(x)
+ i = torch.arange(self.height, device=x.device).reshape(1,
-1,
1).expand_as(x)
+ j = torch.arange(self.width, device=x.device).reshape(1,
1,
-1).expand_as(x)
- ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
- rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)
+ ri = (
+ (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1, 1, 1)
+ )
+ rj = (
+ (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1, 1, 1)
+ )
- m = 1 - torch.logical_or(i
==ri,j==
rj).long().flatten(1)
+ m = 1 - torch.logical_or(i
== ri, j ==
rj).long().flatten(1)
x = x.flatten(1)
x = x.flatten(1)
- u = torch.arange(self.height * self.width, device
=
x.device).reshape(1, -1)
+ u = torch.arange(self.height * self.width, device
=
x.device).reshape(1, -1)
d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
return d
d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
return d
@@
-390,7
+402,15
@@
class ProblemMixing(Problem):
return " | ".join(
[
" ".join(
return " | ".join(
[
" ".join(
- ["-".join([f"{x:02d}" if x < self.height * self.width else "**" for x in s]) for s in r.split(self.width)]
+ [
+ "-".join(
+ [
+ f"{x:02d}" if x < self.height * self.width else "**"
+ for x in s
+ ]
+ )
+ for s in r.split(self.width)
+ ]
)
for r in seq.split(self.height * self.width)
]
)
for r in seq.split(self.height * self.width)
]