Update.
author    Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:37:41 +0000 (12:37 +0200)
committer Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:37:41 +0000 (12:37 +0200)
main.py

diff --git a/main.py b/main.py
index c810eef..4a332b8 100755
--- a/main.py
+++ b/main.py
@@ -498,7 +498,7 @@ for k in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split = 'test'):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
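
Note: the hunk replaces a shifted, next-token-prediction cross-entropy with an unshifted one. Below is a minimal standalone sketch (not taken from the repository) contrasting the two variants, assuming output holds logits of shape (batch, seq_len, vocab_size) and input holds token indices of shape (batch, seq_len); the tensor names and shapes are illustrative assumptions only.

    # Illustrative sketch of the two loss computations in the hunk above.
    # Shapes and names are assumptions, not code from the repository.
    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab_size = 4, 16, 32
    input = torch.randint(vocab_size, (batch, seq_len))   # token indices
    output = torch.randn(batch, seq_len, vocab_size)      # model logits

    # Old variant: logits at position t are scored against the token at
    # position t+1 (next-token prediction), so both tensors are shifted.
    loss_shifted = F.cross_entropy(
        output[:, :-1].transpose(1, 2),  # (batch, vocab_size, seq_len-1)
        input[:, 1:],                    # (batch, seq_len-1)
    )

    # New variant: logits at position t are scored against the token at
    # the same position t, with no shift.
    loss_unshifted = F.cross_entropy(
        output.transpose(1, 2),  # (batch, vocab_size, seq_len)
        input,                   # (batch, seq_len)
    )

The transpose(1, 2) in both variants is needed because F.cross_entropy expects the class dimension in position 1, i.e. logits of shape (batch, vocab_size, seq_len) against targets of shape (batch, seq_len).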