Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index c51035c..969b47f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,14 +16,6 @@ import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
-
-######################################################################
-
 
 def str2bool(x):
     x = x.lower()
@@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=50)
@@ -217,6 +211,14 @@ if args.result_dir is None:
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
@@ -832,7 +834,7 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
@@ -910,10 +912,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
         time_pred_result = time_current_result
 
     checkpoint = {