From: François Fleuret Date: Sat, 21 Oct 2023 16:04:38 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=picoclvr.git;a=commitdiff_plain;h=cb7001fcd7a75eaeaca9ae66fce37e372acf8cc1 Update. --- diff --git a/problems.py b/problems.py index ef48162..819715e 100755 --- a/problems.py +++ b/problems.py @@ -52,6 +52,7 @@ class ProblemDegradation(Problem): def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 + e=result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,14 +61,14 @@ class ProblemDegradation(Problem): d = states[0] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=self.value_max if (d-e).abs().sum() == 0: nb_errors = 0 for k in range(len(states)-1): d=states[k]-states[k+1] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=d[j] e[(j+1)%e.size(0)]=-d[j]//2 e[(j-1)%e.size(0)]=-d[j]//2