Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 1d52b6d..69731ff 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, argparse, time, tqdm, os
+import math, sys, argparse, time, tqdm, os, datetime
 
 import torch, torchvision
 from torch import nn
@@ -257,9 +257,9 @@ default_task_args = {
         "nb_test_samples": 10000,
     },
     "memory": {
-        "model": "4M",
+        "model": "37M",
         "batch_size": 100,
-        "nb_train_samples": 5000,
+        "nb_train_samples": 25000,
         "nb_test_samples": 1000,
     },
     "mixing": {
@@ -718,6 +718,8 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
+time_pred_result = None
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
@@ -776,6 +778,13 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             deterministic_synthesis=args.deterministic_synthesis,
         )
 
+        time_current_result = datetime.datetime.now()
+        if time_pred_result is not None:
+            log_string(
+                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+            )
+        time_pred_result = time_current_result
+
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,
         "model_state": model.state_dict(),