Update.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2017 06:05:25 +0000 (08:05 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2017 06:05:25 +0000 (08:05 +0200)
cnn-svrt.py

index a41d42c..d6c7169 100755 (executable)
@@ -32,6 +32,7 @@ from colorama import Fore, Back, Style
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 from torch import FloatTensor as Tensor
@@ -73,6 +74,9 @@ parser.add_argument('--batch_size',
 parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1)
+
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
@@ -295,6 +299,21 @@ class vignette_logger():
             )
             self.last_t = t
 
+def save_examplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
+
 ######################################################################
 
 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
@@ -357,6 +376,10 @@ for problem_number in map(int, args.problems.split(',')):
             train_set.nb_samples / (time.time() - t))
         )
 
+        if args.nb_exemplar_vignettes > 0:
+            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'examplar_{:d}.png'.format(problem_number))
+
         if args.validation_error_threshold > 0.0:
             validation_set = VignetteSet(problem_number,
                                          args.nb_validation_samples, args.batch_size,