X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=8b8ec124e0545feb873d059491032c4277159299;hp=a2ab1a31ce7dfc903b1b2a342b5d888f6c188045;hb=0b891219d91e981f96e5321bcf0db6c3beea0017;hpb=ea951479345890206211764657ce4d9556af9e76 diff --git a/cnn-svrt.py b/cnn-svrt.py index a2ab1a3..8b8ec12 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -27,6 +27,8 @@ import math from colorama import Fore, Back, Style +# Pytorch + import torch from torch import optim @@ -36,6 +38,8 @@ from torch import nn from torch.nn import functional as fn from torchvision import datasets, transforms, utils +# SVRT + from vignette_set import VignetteSet, CompressedVignetteSet ###################################################################### @@ -165,11 +169,15 @@ for arg in vars(args): for problem_number in range(1, 24): if args.compress_vignettes: - train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size) - test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size) + train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size, + cuda=torch.cuda.is_available()) + test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size, + cuda=torch.cuda.is_available()) else: - train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size) - test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size) + train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size, + cuda=torch.cuda.is_available()) + test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size, + cuda=torch.cuda.is_available()) model = AfrozeShallowNet()