projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
02c4828
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 5 Jul 2023 06:47:03 +0000
(08:47 +0200)
committer
François Fleuret
<francois@fleuret.org>
Wed, 5 Jul 2023 06:47:03 +0000
(08:47 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
e1f619c
..
15e6d99
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-1116,6
+1116,9
@@
class TaskExpr(Task):
nb_total = input.size(0)
nb_correct = (input == result).long().min(1).values.sum()
nb_total = input.size(0)
nb_correct = (input == result).long().min(1).values.sum()
+ #######################################################################
+ # Comput predicted vs. true variable values
+
values_input = expr.extract_results([self.seq2str(s) for s in input])
max_input = max([max(x.values()) for x in values_input])
values_result = expr.extract_results([self.seq2str(s) for s in result])
values_input = expr.extract_results([self.seq2str(s) for s in input])
max_input = max([max(x.values()) for x in values_input])
values_result = expr.extract_results([self.seq2str(s) for s in result])
@@
-1123,9
+1126,9
@@
class TaskExpr(Task):
[-1 if len(x) == 0 else max(x.values()) for x in values_result]
)
[-1 if len(x) == 0 else max(x.values()) for x in values_result]
)
- nb_missing
, nb_predicted = torch.zeros(max_input + 1), torch.zeros(
- max_input + 1, max_result + 1
- )
+ nb_missing
= torch.zeros(max_input + 1)
+ nb_predicted = torch.zeros(max_input + 1, max_result + 1)
+
for i, r in zip(values_input, values_result):
for n, vi in i.items():
vr = r.get(n)
for i, r in zip(values_input, values_result):
for n, vi in i.items():
vr = r.get(n)
@@
-1133,6
+1136,7
@@
class TaskExpr(Task):
nb_missing[vi] += 1
else:
nb_predicted[vi, vr] += 1
nb_missing[vi] += 1
else:
nb_predicted[vi, vr] += 1
+ ######################################################################
return nb_total, nb_correct
return nb_total, nb_correct