From 74a872447b15747c7eb576b24344ed6b48d01642 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 16 Jun 2017 14:19:06 +0200 Subject: [PATCH] Prints the ETA. --- cnn-svrt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cnn-svrt.py b/cnn-svrt.py index 5913345..06e58bd 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -146,6 +146,8 @@ def train_model(model, train_set): optimizer = optim.SGD(model.parameters(), lr = 1e-2) + start_t = time.time() + for e in range(0, args.nb_epochs): acc_loss = 0.0 for b in range(0, train_set.nb_batches): @@ -157,6 +159,8 @@ def train_model(model, train_set): loss.backward() optimizer.step() log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss)) + dt = (time.time() - t) / (e + 1) + print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL) return model -- 2.20.1