Heavy fix.
[pysvrt.git] / cnn-svrt.py
index 7bef242..694f035 100755 (executable)
@@ -184,15 +184,19 @@ for problem_number in range(1, 24):
     for p in model.parameters(): nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
+    need_to_train = False
     try:
-
         model.load_state_dict(torch.load(model_filename))
         log_string('loaded_model ' + model_filename)
-
     except:
+        need_to_train = True
+
+    if need_to_train:
 
         log_string('training_model ' + model_filename)
 
+        t = time.time()
+
         if args.compress_vignettes:
             train_set = CompressedVignetteSet(problem_number,
                                               args.nb_train_batches, args.batch_size,
@@ -208,6 +212,10 @@ for problem_number in range(1, 24):
                                    args.nb_test_batches, args.batch_size,
                                    cuda=torch.cuda.is_available())
 
+        log_string('data_generation {:0.2f} samples / s'.format(
+            (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
+        )
+
         train_model(model, train_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)