Update.
authorFrançois Fleuret <fleuret@meta.com>
Fri, 13 Jun 2025 12:24:17 +0000 (14:24 +0200)
committerFrançois Fleuret <fleuret@meta.com>
Fri, 13 Jun 2025 12:24:17 +0000 (14:24 +0200)
tinyae.py

index b4f3aba..0baa5a2 100755 (executable)
--- a/tinyae.py
+++ b/tinyae.py
@@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std)
 
 ######################################################################
 
-for epoch in range(args.nb_epochs):
-    acc_loss = 0
+for n_epoch in range(args.nb_epochs):
+    acc_train_loss = 0
 
     for input in train_input.split(args.batch_size):
         output = model(input)
-        loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
+        train_loss = F.mse_loss(output, input)
 
         optimizer.zero_grad()
-        loss.backward()
+        train_loss.backward()
         optimizer.step()
 
-        acc_loss += loss.item()
+        acc_train_loss += train_loss.detach().item() * input.size(0)
 
-    log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
+    train_loss = acc_train_loss / train_input.size(0)
+
+    log_string(f"train_loss {n_epoch} {train_loss}")
 
 ######################################################################