Minor update.
[pysvrt.git] / cnn-svrt.py
index 694f035..a6b9cab 100755 (executable)
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
-#  along with selector.  If not, see <http://www.gnu.org/licenses/>.
+#  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
 
 import time
 import argparse
 import math
 
 
 import time
 import argparse
 import math
 
+import distutils.util
+import re
+import signal
+
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
 from colorama import Fore, Back, Style
 
 # Pytorch
 
 import torch
+import torchvision
 
 from torch import optim
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 # SVRT
 
 from torchvision import datasets, transforms, utils
 
 # SVRT
 
-from vignette_set import VignetteSet, CompressedVignetteSet
+import svrtset
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'Simple convnet test on the SVRT.',
+    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
-parser.add_argument('--nb_train_batches',
-                    type = int, default = 1000,
-                    help = 'How many samples for train')
+parser.add_argument('--nb_train_samples',
+                    type = int, default = 100000)
+
+parser.add_argument('--nb_test_samples',
+                    type = int, default = 10000)
 
 
-parser.add_argument('--nb_test_batches',
-                    type = int, default = 100,
-                    help = 'How many samples for test')
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
 
 parser.add_argument('--nb_epochs',
 
 parser.add_argument('--nb_epochs',
-                    type = int, default = 50,
-                    help = 'How many training epochs')
+                    type = int, default = 50)
 
 parser.add_argument('--batch_size',
 
 parser.add_argument('--batch_size',
-                    type = int, default = 100,
-                    help = 'Mini-batch size')
+                    type = int, default = 100)
 
 parser.add_argument('--log_file',
 
 parser.add_argument('--log_file',
-                    type = str, default = 'cnn-svrt.log',
-                    help = 'Log file name')
+                    type = str, default = 'default.log')
+
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = 32)
 
 parser.add_argument('--compress_vignettes',
 
 parser.add_argument('--compress_vignettes',
-                    action='store_true', default = False,
+                    type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 
                     help = 'Use lossless compression to reduce the memory footprint')
 
+parser.add_argument('--save_test_mistakes',
+                    type = distutils.util.strtobool, default = 'False')
+
+parser.add_argument('--model',
+                    type = str, default = 'deepnet',
+                    help = 'What model to use')
+
+parser.add_argument('--test_loaded_models',
+                    type = distutils.util.strtobool, default = 'False',
+                    help = 'Should we compute the test errors of loaded models')
+
+parser.add_argument('--problems',
+                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
+                    help = 'What problems to process')
+
 args = parser.parse_args()
 
 ######################################################################
 
 args = parser.parse_args()
 
 ######################################################################
 
-log_file = open(args.log_file, 'w')
+log_file = open(args.log_file, 'a')
+log_file.write('\n')
+log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')
+log_file.write('\n')
+
+pred_log_t = None
+last_tag_t = time.time()
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
 
 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
 
-def log_string(s):
-    s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
-    log_file.write(s + '\n')
+# Log and prints the string, with a time stamp. Does not log the
+# remark
+
+def log_string(s, remark = ''):
+    global pred_log_t, last_tag_t
+
+    t = time.time()
+
+    if pred_log_t is None:
+        elapsed = 'start'
+    else:
+        elapsed = '+{:.02f}s'.format(t - pred_log_t)
+
+    pred_log_t = t
+
+    if t > last_tag_t + 3600:
+        last_tag_t = t
+        print(Fore.RED + time.ctime() + Style.RESET_ALL)
+
+    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
     log_file.flush()
     log_file.flush()
-    print(s)
+
+    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \
+          + Style.RESET_ALL
+          + ' ' \
+          + s + Fore.CYAN + remark \
+          + Style.RESET_ALL)
+
+######################################################################
+
+def handler_sigint(signum, frame):
+    log_string('got sigint')
+    exit(0)
+
+def handler_sigterm(signum, frame):
+    log_string('got sigterm')
+    exit(0)
+
+signal.signal(signal.SIGINT, handler_sigint)
+signal.signal(signal.SIGTERM, handler_sigterm)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -104,6 +172,8 @@ def log_string(s):
 # -- full(84x2)        -> 2          1
 
 class AfrozeShallowNet(nn.Module):
 # -- full(84x2)        -> 2          1
 
 class AfrozeShallowNet(nn.Module):
+    name = 'shallownet'
+
     def __init__(self):
         super(AfrozeShallowNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
     def __init__(self):
         super(AfrozeShallowNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
@@ -111,7 +181,6 @@ class AfrozeShallowNet(nn.Module):
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
-        self.name = 'shallownet'
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
@@ -124,7 +193,182 @@ class AfrozeShallowNet(nn.Module):
 
 ######################################################################
 
 
 ######################################################################
 
-def train_model(model, train_set):
+# Afroze's DeepNet
+
+class AfrozeDeepNet(nn.Module):
+
+    name = 'deepnet'
+
+    def __init__(self):
+        super(AfrozeDeepNet, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(1536, 256)
+        self.fc2 = nn.Linear(256, 256)
+        self.fc3 = nn.Linear(256, 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = x.view(-1, 1536)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
+class DeepNet2(nn.Module):
+    name = 'deepnet2'
+
+    def __init__(self):
+        super(DeepNet2, self).__init__()
+        self.nb_channels = 512
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32, self.nb_channels, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(16 * self.nb_channels, 512)
+        self.fc2 = nn.Linear(512, 512)
+        self.fc3 = nn.Linear(512, 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = x.view(-1, 16 * self.nb_channels)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
+class DeepNet3(nn.Module):
+    name = 'deepnet3'
+
+    def __init__(self):
+        super(DeepNet3, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(2048, 256)
+        self.fc2 = nn.Linear(256, 256)
+        self.fc3 = nn.Linear(256, 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv6(x)
+        x = fn.relu(x)
+
+        x = self.conv7(x)
+        x = fn.relu(x)
+
+        x = x.view(-1, 2048)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
+def nb_errors(model, data_set, mistake_filename_pattern = None):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+                if mistake_filename_pattern is not None:
+                    img = input[i].clone()
+                    img.sub_(img.min())
+                    img.div_(img.max())
+                    k = b * data_set.batch_size + i
+                    filename = mistake_filename_pattern.format(k, target[i])
+                    torchvision.utils.save_image(img, filename)
+                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)
+    return ne
+
+######################################################################
+
+def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
@@ -133,7 +377,9 @@ def train_model(model, train_set):
 
     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
 
 
     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
 
-    for e in range(0, args.nb_epochs):
+    start_t = time.time()
+
+    for e in range(nb_epochs_done, args.nb_epochs):
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
@@ -143,24 +389,27 @@ def train_model(model, train_set):
             model.zero_grad()
             loss.backward()
             optimizer.step()
             model.zero_grad()
             loss.backward()
             optimizer.step()
-        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
+        dt = (time.time() - start_t) / (e + 1)
 
 
-    return model
+        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
+                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
 
-######################################################################
+        torch.save([ model.state_dict(), e + 1 ], model_filename)
 
 
-def nb_errors(model, data_set):
-    ne = 0
-    for b in range(0, data_set.nb_batches):
-        input, target = data_set.get_batch(b)
-        output = model.forward(Variable(input))
-        wta_prediction = output.data.max(1)[1].view(-1)
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)
 
 
-        for i in range(0, data_set.batch_size):
-            if wta_prediction[i] != target[i]:
-                ne = ne + 1
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
 
 
-    return ne
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break
+
+    return model
 
 ######################################################################
 
 
 ######################################################################
 
@@ -169,55 +418,132 @@ for arg in vars(args):
 
 ######################################################################
 
 
 ######################################################################
 
-for problem_number in range(1, 24):
+def int_to_suffix(n):
+    if n >= 1000000 and n%1000000 == 0:
+        return str(n//1000000) + 'M'
+    elif n >= 1000 and n%1000 == 0:
+        return str(n//1000) + 'K'
+    else:
+        return str(n)
+
+class vignette_logger():
+    def __init__(self, delay_min = 60):
+        self.start_t = time.time()
+        self.last_t = self.start_t
+        self.delay_min = delay_min
+
+    def __call__(self, n, m):
+        t = time.time()
+        if t > self.last_t + self.delay_min:
+            dt = (t - self.start_t) / m
+            log_string('sample_generation {:d} / {:d}'.format(
+                m,
+                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
+            )
+            self.last_t = t
+
+def save_exemplar_vignettes(data_set, nb, name):
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    torchvision.utils.save_image(patchwork, name)
 
 
-    model = AfrozeShallowNet()
+######################################################################
 
 
-    if torch.cuda.is_available():
-        model.cuda()
+if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
+    print('The number of samples must be a multiple of the batch size.')
+    raise
+
+if args.compress_vignettes:
+    log_string('using_compressed_vignettes')
+    VignetteSet = svrtset.CompressedVignetteSet
+else:
+    log_string('using_uncompressed_vignettes')
+    VignetteSet = svrtset.VignetteSet
+
+########################################
+model_class = None
+for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
+    if args.model == m.name:
+        model_class = m
+        break
+if model_class is None:
+    print('Unknown model ' + args.model)
+    raise
+
+log_string('using model class ' + m.name)
+########################################
+
+for problem_number in map(int, args.problems.split(',')):
 
 
-    model_filename = model.name + '_' + \
-                     str(problem_number) + '_' + \
-                     str(args.nb_train_batches) + '.param'
+    log_string('############### problem ' + str(problem_number) + ' ###############')
+
+    model = model_class()
+
+    if torch.cuda.is_available(): model.cuda()
+
+    model_filename = model.name + '_pb:' + \
+                     str(problem_number) + '_ns:' + \
+                     int_to_suffix(args.nb_train_samples) + '.pth'
 
     nb_parameters = 0
     for p in model.parameters(): nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
 
     nb_parameters = 0
     for p in model.parameters(): nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
-    need_to_train = False
+    ##################################################
+    # Tries to load the model
+
     try:
     try:
-        model.load_state_dict(torch.load(model_filename))
+        model_state_dict, nb_epochs_done = torch.load(model_filename)
+        model.load_state_dict(model_state_dict)
         log_string('loaded_model ' + model_filename)
     except:
         log_string('loaded_model ' + model_filename)
     except:
-        need_to_train = True
+        nb_epochs_done = 0
 
 
-    if need_to_train:
+
+    ##################################################
+    # Train if necessary
+
+    if nb_epochs_done < args.nb_epochs:
 
         log_string('training_model ' + model_filename)
 
         t = time.time()
 
 
         log_string('training_model ' + model_filename)
 
         t = time.time()
 
-        if args.compress_vignettes:
-            train_set = CompressedVignetteSet(problem_number,
-                                              args.nb_train_batches, args.batch_size,
-                                              cuda=torch.cuda.is_available())
-            test_set = CompressedVignetteSet(problem_number,
-                                             args.nb_test_batches, args.batch_size,
-                                             cuda=torch.cuda.is_available())
-        else:
-            train_set = VignetteSet(problem_number,
-                                    args.nb_train_batches, args.batch_size,
-                                    cuda=torch.cuda.is_available())
-            test_set = VignetteSet(problem_number,
-                                   args.nb_test_batches, args.batch_size,
-                                   cuda=torch.cuda.is_available())
+        train_set = VignetteSet(problem_number,
+                                args.nb_train_samples, args.batch_size,
+                                cuda = torch.cuda.is_available(),
+                                logger = vignette_logger())
 
         log_string('data_generation {:0.2f} samples / s'.format(
 
         log_string('data_generation {:0.2f} samples / s'.format(
-            (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
+            train_set.nb_samples / (time.time() - t))
         )
 
         )
 
-        train_model(model, train_set)
-        torch.save(model.state_dict(), model_filename)
+        if args.nb_exemplar_vignettes > 0:
+            save_exemplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                                    'exemplar_{:d}.png'.format(problem_number))
+
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, model_filename,
+                    train_set, validation_set,
+                    nb_epochs_done = nb_epochs_done)
+
         log_string('saved_model ' + model_filename)
 
         nb_train_errors = nb_errors(model, train_set)
         log_string('saved_model ' + model_filename)
 
         nb_train_errors = nb_errors(model, train_set)
@@ -229,7 +555,19 @@ for problem_number in range(1, 24):
             train_set.nb_samples)
         )
 
             train_set.nb_samples)
         )
 
-        nb_test_errors = nb_errors(model, test_set)
+    ##################################################
+    # Test if necessary
+
+    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:
+
+        t = time.time()
+
+        test_set = VignetteSet(problem_number,
+                               args.nb_test_samples, args.batch_size,
+                               cuda = torch.cuda.is_available())
+
+        nb_test_errors = nb_errors(model, test_set,
+                                   mistake_filename_pattern = 'mistake_{:06d}_{:d}.png')
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
             problem_number,
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
             problem_number,