diff --git a/main.py b/main.py
index c51035c..969b47f 100755
--- a/main.py
+++ b/main.py
@@ -16,14 +16,6 @@ import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
-
-######################################################################
-
 
 def str2bool(x):
     x = x.lower()
@@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=50)
@@ -217,6 +211,14 @@ if args.result_dir is None:
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
@@ -832,7 +834,7 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
    )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
@@ -910,10 +912,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             )
 
             time_current_result = datetime.datetime.now()
-            if time_pred_result is not None:
-                log_string(
-                    f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-                )
+            log_string(
+                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+            )
             time_pred_result = time_current_result
 
     checkpoint = {
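
What the first three hunks amount to in practice: device selection is moved from import time to after argument parsing, gated by a new --force_cpu flag parsed with the existing str2bool helper. Below is a minimal, self-contained sketch of that pattern; the body of str2bool (which strings it accepts) and the surrounding script are assumptions for illustration, not a copy of main.py.

# Sketch only: choose the compute device after parsing arguments, so a
# --force_cpu flag can override CUDA detection. The accepted boolean
# strings in str2bool are an assumption.

import argparse

import torch


def str2bool(x):
    x = x.lower()
    if x in {"1", "true", "yes"}:
        return True
    elif x in {"0", "false", "no"}:
        return False
    else:
        raise ValueError(f"cannot interpret {x!r} as a boolean")


parser = argparse.ArgumentParser()
parser.add_argument("--force_cpu", type=str2bool, default=False)
args = parser.parse_args()

if not args.force_cpu and torch.cuda.is_available():
    device = torch.device("cuda")
    # TF32 matmuls trade a little precision for speed on Ampere and newer GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

print(f"running on {device}")

Note that with type=str2bool the flag takes an explicit value (e.g. python main.py --force_cpu yes), unlike argparse's store_true action.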
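
The last two hunks change how the "next_result" ETA is logged: time_pred_result is now seeded with datetime.datetime.now() at start-up instead of None, so the None guard around the log call can be dropped and the very first evaluation already prints an estimate, extrapolated from the time elapsed since start-up. A small standalone sketch of that logic, with a stand-in log_string and a fake training loop (both illustrative, not the repository's code):

import datetime
import time


def log_string(s):
    # Stand-in for the logger in main.py.
    print(f"{datetime.datetime.now()} {s}")


# Seeding with "now" instead of None means the first prediction is based on
# the time elapsed since start-up rather than being skipped.
time_pred_result = datetime.datetime.now()

for n_epoch in range(3):
    time.sleep(1.0)  # stand-in for one epoch of training and evaluation

    time_current_result = datetime.datetime.now()
    # Predict the next result time by adding the last inter-result gap.
    log_string(
        f"next_result {time_current_result + (time_current_result - time_pred_result)}"
    )
    time_pred_result = time_current_result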