From 0b891219d91e981f96e5321bcf0db6c3beea0017 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 16 Jun 2017 07:54:04 +0200 Subject: [PATCH] Pass the use of cuda to the VignetteSet constructor. --- cnn-svrt.py | 16 ++++++++++++---- vignette_set.py | 10 ++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cnn-svrt.py b/cnn-svrt.py index a2ab1a3..8b8ec12 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -27,6 +27,8 @@ import math from colorama import Fore, Back, Style +# Pytorch + import torch from torch import optim @@ -36,6 +38,8 @@ from torch import nn from torch.nn import functional as fn from torchvision import datasets, transforms, utils +# SVRT + from vignette_set import VignetteSet, CompressedVignetteSet ###################################################################### @@ -165,11 +169,15 @@ for arg in vars(args): for problem_number in range(1, 24): if args.compress_vignettes: - train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size) - test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size) + train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size, + cuda=torch.cuda.is_available()) + test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size, + cuda=torch.cuda.is_available()) else: - train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size) - test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size) + train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size, + cuda=torch.cuda.is_available()) + test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size, + cuda=torch.cuda.is_available()) model = AfrozeShallowNet() diff --git a/vignette_set.py b/vignette_set.py index 0ed3d39..695fed3 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -31,7 +31,8 @@ import svrt ###################################################################### class VignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number self.nb_batches = nb_batches @@ -46,7 +47,7 @@ class VignetteSet: target = torch.LongTensor(self.batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if torch.cuda.is_available(): + if self.cuda: input = input.cuda() target = target.cuda() acc += input.sum() / input.numel() @@ -65,7 +66,8 @@ class VignetteSet: ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches, batch_size): + def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number self.nb_batches = nb_batches @@ -91,7 +93,7 @@ class CompressedVignetteSet: input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std) target = self.targets[b] - if torch.cuda.is_available(): + if self.cuda: input = input.cuda() target = target.cuda() -- 2.20.1