X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=218ff36e0f7e37f67de3e4e227457b61d6414a60;hb=8012a611e9920816fe6ba382b69305242136bc2a;hp=727b196b3a2f008854fb314389254589ab29d715;hpb=a1ae050705970007f965d2586c53e9bd262e46aa;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 727b196..218ff36 100755 --- a/tasks.py +++ b/tasks.py @@ -250,7 +250,13 @@ class PicoCLVR(Task): # Make a list of strings from a tensor def detensorize(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its @@ -888,7 +894,10 @@ class Stack(Task): def compute_nb_correct(input): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + result *= 1 - ar_mask + masked_inplace_autoregression( model, self.batch_size, @@ -923,10 +932,12 @@ class Stack(Task): stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - # for n in range(result.size(0)): - # logger( - # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - # ) + for n in range(result.size(0)): + logger( + f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + + result *= 1 - ar_mask masked_inplace_autoregression( model, @@ -1448,7 +1459,13 @@ class Grid(Task): # Make a list of strings from a tensor def tensor2str(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its