X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=3fe50d8833740354ffe9595c268233adb36a4384;hb=ffe0b4fed11bb356684d9faa1849c86997a3029a;hp=8baaacbc4fe7b4b37d8e6be27dd229b5b44bf6cc;hpb=349b55a2d9ca213718df8941058d42689ba68163;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 8baaacb..3fe50d8 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -24,8 +24,10 @@ import time import argparse import math + import distutils.util import re +import signal from colorama import Fore, Back, Style @@ -83,6 +85,9 @@ parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') +parser.add_argument('--save_test_mistakes', + type = distutils.util.strtobool, default = 'False') + parser.add_argument('--model', type = str, default = 'deepnet', help = 'What model to use') @@ -127,7 +132,24 @@ def log_string(s, remark = ''): log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \ + + Style.RESET_ALL + + ' ' \ + + s + Fore.CYAN + remark \ + + Style.RESET_ALL) + +###################################################################### + +def handler_sigint(signum, frame): + log_string('got sigint') + exit(0) + +def handler_sigterm(signum, frame): + log_string('got sigterm') + exit(0) + +signal.signal(signal.SIGINT, handler_sigint) +signal.signal(signal.SIGTERM, handler_sigterm) ###################################################################### @@ -319,7 +341,7 @@ class DeepNet3(nn.Module): ###################################################################### -def nb_errors(model, data_set): +def nb_errors(model, data_set, mistake_filename_pattern = None): ne = 0 for b in range(0, data_set.nb_batches): input, target = data_set.get_batch(b) @@ -329,6 +351,12 @@ def nb_errors(model, data_set): for i in range(0, data_set.batch_size): if wta_prediction[i] != target[i]: ne = ne + 1 + if mistake_filename_pattern is not None: + img = input[i].clone() + img.sub_(img.min()) + img.div_(img.max()) + torchvision.utils.save_image(img, + mistake_filename_pattern.format(b + i, target[i])) return ne @@ -531,7 +559,8 @@ for problem_number in map(int, args.problems.split(',')): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - nb_test_errors = nb_errors(model, test_set) + nb_test_errors = nb_errors(model, test_set, + mistake_filename_pattern = 'mistake_{:d}_{:06d}.png') log_string('test_error {:d} {:.02f}% {:d} {:d}'.format( problem_number,