Added DeepNet3.
[pysvrt.git] / cnn-svrt.py
index d6c7169..8baaacb 100755 (executable)
@@ -35,10 +35,12 @@ import torch
 import torchvision
 
 from torch import optim
 import torchvision
 
 from torch import optim
+from torch import multiprocessing
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 # SVRT
 from torchvision import datasets, transforms, utils
 
 # SVRT
@@ -75,15 +77,15 @@ parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
 parser.add_argument('--nb_exemplar_vignettes',
                     type = str, default = 'default.log')
 
 parser.add_argument('--nb_exemplar_vignettes',
-                    type = int, default = -1)
+                    type = int, default = 32)
 
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 
 
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 
-parser.add_argument('--deep_model',
-                    type = distutils.util.strtobool, default = 'True',
-                    help = 'Use Afroze\'s Alexnet-like deep model')
+parser.add_argument('--model',
+                    type = str, default = 'deepnet',
+                    help = 'What model to use')
 
 parser.add_argument('--test_loaded_models',
                     type = distutils.util.strtobool, default = 'False',
 
 parser.add_argument('--test_loaded_models',
                     type = distutils.util.strtobool, default = 'False',
@@ -144,6 +146,8 @@ def log_string(s, remark = ''):
 # -- full(84x2)        -> 2          1
 
 class AfrozeShallowNet(nn.Module):
 # -- full(84x2)        -> 2          1
 
 class AfrozeShallowNet(nn.Module):
+    name = 'shallownet'
+
     def __init__(self):
         super(AfrozeShallowNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
     def __init__(self):
         super(AfrozeShallowNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
@@ -151,7 +155,6 @@ class AfrozeShallowNet(nn.Module):
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
-        self.name = 'shallownet'
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
@@ -167,6 +170,9 @@ class AfrozeShallowNet(nn.Module):
 # Afroze's DeepNet
 
 class AfrozeDeepNet(nn.Module):
 # Afroze's DeepNet
 
 class AfrozeDeepNet(nn.Module):
+
+    name = 'deepnet'
+
     def __init__(self):
         super(AfrozeDeepNet, self).__init__()
         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
     def __init__(self):
         super(AfrozeDeepNet, self).__init__()
         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
@@ -177,7 +183,6 @@ class AfrozeDeepNet(nn.Module):
         self.fc1 = nn.Linear(1536, 256)
         self.fc2 = nn.Linear(256, 256)
         self.fc3 = nn.Linear(256, 2)
         self.fc1 = nn.Linear(1536, 256)
         self.fc2 = nn.Linear(256, 256)
         self.fc3 = nn.Linear(256, 2)
-        self.name = 'deepnet'
 
     def forward(self, x):
         x = self.conv1(x)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -212,6 +217,108 @@ class AfrozeDeepNet(nn.Module):
 
 ######################################################################
 
 
 ######################################################################
 
+class DeepNet2(nn.Module):
+    name = 'deepnet2'
+
+    def __init__(self):
+        super(DeepNet2, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(4096, 512)
+        self.fc2 = nn.Linear(512, 512)
+        self.fc3 = nn.Linear(512, 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = x.view(-1, 4096)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
+class DeepNet3(nn.Module):
+    name = 'deepnet3'
+
+    def __init__(self):
+        super(DeepNet3, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(2048, 256)
+        self.fc2 = nn.Linear(256, 256)
+        self.fc3 = nn.Linear(256, 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv6(x)
+        x = fn.relu(x)
+
+        x = self.conv7(x)
+        x = fn.relu(x)
+
+        x = x.view(-1, 2048)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
 def nb_errors(model, data_set):
     ne = 0
     for b in range(0, data_set.nb_batches):
 def nb_errors(model, data_set):
     ne = 0
     for b in range(0, data_set.nb_batches):
@@ -227,7 +334,7 @@ def nb_errors(model, data_set):
 
 ######################################################################
 
 
 ######################################################################
 
-def train_model(model, train_set, validation_set):
+def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
@@ -238,7 +345,7 @@ def train_model(model, train_set, validation_set):
 
     start_t = time.time()
 
 
     start_t = time.time()
 
-    for e in range(0, args.nb_epochs):
+    for e in range(nb_epochs_done, args.nb_epochs):
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
         acc_loss = 0.0
         for b in range(0, train_set.nb_batches):
             input, target = train_set.get_batch(b)
@@ -253,6 +360,8 @@ def train_model(model, train_set, validation_set):
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
+        torch.save([ model.state_dict(), e + 1 ], model_filename)
+
         if validation_set is not None:
             nb_validation_errors = nb_errors(model, validation_set)
 
         if validation_set is not None:
             nb_validation_errors = nb_errors(model, validation_set)
 
@@ -329,20 +438,30 @@ else:
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet
 
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet
 
+########################################
+model_class = None
+for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
+    if args.model == m.name:
+        model_class = m
+        break
+if model_class is None:
+    print('Unknown model ' + args.model)
+    raise
+
+log_string('using model class ' + m.name)
+########################################
+
 for problem_number in map(int, args.problems.split(',')):
 
     log_string('############### problem ' + str(problem_number) + ' ###############')
 
 for problem_number in map(int, args.problems.split(',')):
 
     log_string('############### problem ' + str(problem_number) + ' ###############')
 
-    if args.deep_model:
-        model = AfrozeDeepNet()
-    else:
-        model = AfrozeShallowNet()
+    model = model_class()
 
     if torch.cuda.is_available(): model.cuda()
 
     model_filename = model.name + '_pb:' + \
                      str(problem_number) + '_ns:' + \
 
     if torch.cuda.is_available(): model.cuda()
 
     model_filename = model.name + '_pb:' + \
                      str(problem_number) + '_ns:' + \
-                     int_to_suffix(args.nb_train_samples) + '.param'
+                     int_to_suffix(args.nb_train_samples) + '.state'
 
     nb_parameters = 0
     for p in model.parameters(): nb_parameters += p.numel()
 
     nb_parameters = 0
     for p in model.parameters(): nb_parameters += p.numel()
@@ -351,17 +470,18 @@ for problem_number in map(int, args.problems.split(',')):
     ##################################################
     # Tries to load the model
 
     ##################################################
     # Tries to load the model
 
-    need_to_train = False
     try:
     try:
-        model.load_state_dict(torch.load(model_filename))
+        model_state_dict, nb_epochs_done = torch.load(model_filename)
+        model.load_state_dict(model_state_dict)
         log_string('loaded_model ' + model_filename)
     except:
         log_string('loaded_model ' + model_filename)
     except:
-        need_to_train = True
+        nb_epochs_done = 0
+
 
     ##################################################
     # Train if necessary
 
 
     ##################################################
     # Train if necessary
 
-    if need_to_train:
+    if nb_epochs_done < args.nb_epochs:
 
         log_string('training_model ' + model_filename)
 
 
         log_string('training_model ' + model_filename)
 
@@ -388,8 +508,7 @@ for problem_number in map(int, args.problems.split(',')):
         else:
             validation_set = None
 
         else:
             validation_set = None
 
-        train_model(model, train_set, validation_set)
-        torch.save(model.state_dict(), model_filename)
+        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
         log_string('saved_model ' + model_filename)
 
         nb_train_errors = nb_errors(model, train_set)
         log_string('saved_model ' + model_filename)
 
         nb_train_errors = nb_errors(model, train_set)
@@ -404,7 +523,7 @@ for problem_number in map(int, args.problems.split(',')):
     ##################################################
     # Test if necessary
 
     ##################################################
     # Test if necessary
 
-    if need_to_train or args.test_loaded_models:
+    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:
 
         t = time.time()
 
 
         t = time.time()