From: Francois Fleuret Date: Tue, 26 Jul 2022 10:37:41 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=1e7c259b1dd038a0f45dba96e872cd1121f38f96 Update. --- diff --git a/main.py b/main.py index c810eef..4a332b8 100755 --- a/main.py +++ b/main.py @@ -498,7 +498,7 @@ for k in range(nb_epochs_finished, nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0)