Pass the use of cuda to the VignetteSet constructor.
[pysvrt.git] / cnn-svrt.py
index 084606a..8b8ec12 100755 (executable)
@@ -27,6 +27,8 @@ import math
 
 from colorama import Fore, Back, Style
 
+# Pytorch
+
 import torch
 
 from torch import optim
@@ -36,6 +38,8 @@ from torch import nn
 from torch.nn import functional as fn
 from torchvision import datasets, transforms, utils
 
+# SVRT
+
 from vignette_set import VignetteSet, CompressedVignetteSet
 
 ######################################################################
@@ -107,6 +111,7 @@ class AfrozeShallowNet(nn.Module):
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
+        self.name = 'shallownet'
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
@@ -117,6 +122,8 @@ class AfrozeShallowNet(nn.Module):
         x = self.fc2(x)
         return x
 
+######################################################################
+
 def train_model(model, train_set):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
@@ -162,11 +169,15 @@ for arg in vars(args):
 
 for problem_number in range(1, 24):
     if args.compress_vignettes:
-        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                         cuda=torch.cuda.is_available())
     else:
-        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
 
     model = AfrozeShallowNet()
 
@@ -178,7 +189,7 @@ for problem_number in range(1, 24):
         nb_parameters += p.numel()
     log_string('nb_parameters {:d}'.format(nb_parameters))
 
-    model_filename = 'model_' + str(problem_number) + '.param'
+    model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param'
 
     try:
         model.load_state_dict(torch.load(model_filename))