X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=e973291a15517bb3c33a96f6f3f1ab7db4545b86;hb=38c9162209ddc1894da6805a3c7459d8c2b3a13d;hp=4a332b82ab26318d210b479dcbebb5bc25ec6b38;hpb=84748a01e6d3c26037412592ce147b7753ce6117;p=mygpt.git diff --git a/main.py b/main.py index 4a332b8..e973291 100755 --- a/main.py +++ b/main.py @@ -204,7 +204,7 @@ class TaskPicoCLVR(Task): t_generated = [ ] for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ] input = torch.tensor(t, device = self.device) output = model(input) logits = output[0, -1]