######################################################################
-for epoch in range(args.nb_epochs):
- acc_loss = 0
+for n_epoch in range(args.nb_epochs):
+ acc_train_loss = 0
for input in train_input.split(args.batch_size):
output = model(input)
- loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
+ train_loss = F.mse_loss(output, input)
optimizer.zero_grad()
- loss.backward()
+ train_loss.backward()
optimizer.step()
- acc_loss += loss.item()
+ acc_train_loss += train_loss.detach().item() * input.size(0)
- log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
+ train_loss = acc_train_loss / train_input.size(0)
+
+ log_string(f"train_loss {n_epoch} {train_loss}")
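
# For reference, a minimal self-contained sketch of the training loop as it
# reads after the change above. The Args class, model, optimizer, dummy data
# and log_string helper are placeholders assumed here for illustration; only
# the loop itself mirrors the diff.

import torch
from torch import nn, optim
from torch.nn import functional as F

class Args:  # stand-in for the argparse namespace used in the diff
    nb_epochs = 10
    batch_size = 100

args = Args()

def log_string(s):  # placeholder for the logging helper assumed by the diff
    print(s)

train_input = torch.randn(1000, 32)  # dummy training data for illustration

model = nn.Sequential(nn.Linear(32, 8), nn.ReLU(), nn.Linear(8, 32))
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for n_epoch in range(args.nb_epochs):
    acc_train_loss = 0
    for input in train_input.split(args.batch_size):
        output = model(input)
        # mean squared error averaged over all elements of the batch
        train_loss = F.mse_loss(output, input)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        # weight by the batch size so that the per-sample average computed
        # after the loop is exact even if the last batch is smaller
        acc_train_loss += train_loss.detach().item() * input.size(0)
    train_loss = acc_train_loss / train_input.size(0)
    log_string(f"train_loss {n_epoch} {train_loss}")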
######################################################################