X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=4a332b82ab26318d210b479dcbebb5bc25ec6b38;hb=1e7c259b1dd038a0f45dba96e872cd1121f38f96;hp=c810eef06593c7939e0ad15fe690225509cd4150;hpb=fc570d4ccd5d5dee36271d34ff5c672a50a82101;p=mygpt.git diff --git a/main.py b/main.py index c810eef..4a332b8 100755 --- a/main.py +++ b/main.py @@ -498,7 +498,7 @@ for k in range(nb_epochs_finished, nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0)