Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Oct 2023 16:04:38 +0000 (18:04 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Oct 2023 16:04:38 +0000 (18:04 +0200)
problems.py

index ef48162..819715e 100755 (executable)
@@ -52,6 +52,7 @@ class ProblemDegradation(Problem):
     def compute_nb_correct(self, input, ar_mask, result):
         nb_total = result.size(0)
         nb_correct = 0
+        e=result.new_zeros(self.nb_state_tokens)
 
         for seq in result:
             states = list(seq.split(self.nb_state_tokens))
@@ -60,14 +61,14 @@ class ProblemDegradation(Problem):
 
             d = states[0]
             j=d.sort(descending=True).indices[0]
-            e=d.new_zeros(d.size())
+            e.zero_()
             e[j]=self.value_max
             if (d-e).abs().sum() == 0:
                 nb_errors = 0
                 for k in range(len(states)-1):
                     d=states[k]-states[k+1]
                     j=d.sort(descending=True).indices[0]
-                    e=d.new_zeros(d.size())
+                    e.zero_()
                     e[j]=d[j]
                     e[(j+1)%e.size(0)]=-d[j]//2
                     e[(j-1)%e.size(0)]=-d[j]//2