X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=f3d350eb9ea9203f408a0603e7c0458b88801e95;hb=212e14ec93489fcaa8a039c4bc64abeb8852c5ec;hp=153bdc9d23a18a7abe67cfbe3f72246a5ee2fa83;hpb=d21f7d8eecb12aa4cc60360db6aa33324327e987;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 153bdc9..f3d350e 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -19,11 +19,12 @@ # General Public License for more details. # # You should have received a copy of the GNU General Public License -# along with selector. If not, see . +# along with svrt. If not, see . import time import argparse import math +import distutils.util from colorama import Fore, Back, Style @@ -40,7 +41,7 @@ from torchvision import datasets, transforms, utils # SVRT -import vignette_set +import svrtset ###################################################################### @@ -55,6 +56,13 @@ parser.add_argument('--nb_train_samples', parser.add_argument('--nb_test_samples', type = int, default = 10000) +parser.add_argument('--nb_validation_samples', + type = int, default = 10000) + +parser.add_argument('--validation_error_threshold', + type = float, default = 0.0, + help = 'Early training termination criterion') + parser.add_argument('--nb_epochs', type = int, default = 50) @@ -65,22 +73,26 @@ parser.add_argument('--log_file', type = str, default = 'default.log') parser.add_argument('--compress_vignettes', - action='store_true', default = True, + type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') parser.add_argument('--deep_model', - action='store_true', default = True, + type = distutils.util.strtobool, default = 'True', help = 'Use Afroze\'s Alexnet-like deep model') parser.add_argument('--test_loaded_models', - action='store_true', default = False, + type = distutils.util.strtobool, default = 'False', help = 'Should we compute the test errors of loaded models') +parser.add_argument('--problems', + type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23', + help = 'What problems to process') + args = parser.parse_args() ###################################################################### -log_file = open(args.log_file, 'w') +log_file = open(args.log_file, 'a') pred_log_t = None print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) @@ -143,22 +155,6 @@ class AfrozeShallowNet(nn.Module): # Afroze's DeepNet -# map size nb. maps -# ---------------------- -# input 128x128 1 -# -- conv(21x21 x 32 stride=4) -> 28x28 32 -# -- max(2x2) -> 14x14 6 -# -- conv(7x7 x 96) -> 8x8 16 -# -- max(2x2) -> 4x4 16 -# -- conv(5x5 x 96) -> 26x36 16 -# -- conv(3x3 x 128) -> 36x36 16 -# -- conv(3x3 x 128) -> 36x36 16 - -# -- conv(5x5 x 120) -> 1x1 120 -# -- reshape -> 120 1 -# -- full(3x84) -> 84 1 -# -- full(84x2) -> 2 1 - class AfrozeDeepNet(nn.Module): def __init__(self): super(AfrozeDeepNet, self).__init__() @@ -205,7 +201,22 @@ class AfrozeDeepNet(nn.Module): ###################################################################### -def train_model(model, train_set): +def nb_errors(model, data_set): + ne = 0 + for b in range(0, data_set.nb_batches): + input, target = data_set.get_batch(b) + output = model.forward(Variable(input)) + wta_prediction = output.data.max(1)[1].view(-1) + + for i in range(0, data_set.batch_size): + if wta_prediction[i] != target[i]: + ne = ne + 1 + + return ne + +###################################################################### + +def train_model(model, train_set, validation_set): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -227,25 +238,24 @@ def train_model(model, train_set): loss.backward() optimizer.step() dt = (time.time() - start_t) / (e + 1) + log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') - return model - -###################################################################### + if validation_set is not None: + nb_validation_errors = nb_errors(model, validation_set) -def nb_errors(model, data_set): - ne = 0 - for b in range(0, data_set.nb_batches): - input, target = data_set.get_batch(b) - output = model.forward(Variable(input)) - wta_prediction = output.data.max(1)[1].view(-1) + log_string('validation_error {:.02f}% {:d} {:d}'.format( + 100 * nb_validation_errors / validation_set.nb_samples, + nb_validation_errors, + validation_set.nb_samples) + ) - for i in range(0, data_set.batch_size): - if wta_prediction[i] != target[i]: - ne = ne + 1 + if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold: + log_string('below validation_error_threshold') + break - return ne + return model ###################################################################### @@ -255,25 +265,45 @@ for arg in vars(args): ###################################################################### def int_to_suffix(n): - if n > 1000000 and n%1000000 == 0: + if n >= 1000000 and n%1000000 == 0: return str(n//1000000) + 'M' - elif n > 1000 and n%1000 == 0: + elif n >= 1000 and n%1000 == 0: return str(n//1000) + 'K' else: return str(n) +class vignette_logger(): + def __init__(self, delay_min = 60): + self.start_t = time.time() + self.last_t = self.start_t + self.delay_min = delay_min + + def __call__(self, n, m): + t = time.time() + if t > self.last_t + self.delay_min: + dt = (t - self.start_t) / m + log_string('sample_generation {:d} / {:d}'.format( + m, + n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']' + ) + self.last_t = t + ###################################################################### if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0: print('The number of samples must be a multiple of the batch size.') raise +log_string('############### start ###############') + if args.compress_vignettes: - VignetteSet = vignette_set.CompressedVignetteSet + log_string('using_compressed_vignettes') + VignetteSet = svrtset.CompressedVignetteSet else: - VignetteSet = vignette_set.VignetteSet + log_string('using_uncompressed_vignettes') + VignetteSet = svrtset.VignetteSet -for problem_number in range(1, 24): +for problem_number in map(int, args.problems.split(',')): log_string('############### problem ' + str(problem_number) + ' ###############') @@ -284,8 +314,8 @@ for problem_number in range(1, 24): if torch.cuda.is_available(): model.cuda() - model_filename = model.name + '_' + \ - str(problem_number) + '_' + \ + model_filename = model.name + '_pb:' + \ + str(problem_number) + '_ns:' + \ int_to_suffix(args.nb_train_samples) + '.param' nb_parameters = 0 @@ -313,13 +343,22 @@ for problem_number in range(1, 24): train_set = VignetteSet(problem_number, args.nb_train_samples, args.batch_size, - cuda = torch.cuda.is_available()) + cuda = torch.cuda.is_available(), + logger = vignette_logger()) log_string('data_generation {:0.2f} samples / s'.format( train_set.nb_samples / (time.time() - t)) ) - train_model(model, train_set) + if args.validation_error_threshold > 0.0: + validation_set = VignetteSet(problem_number, + args.nb_validation_samples, args.batch_size, + cuda = torch.cuda.is_available(), + logger = vignette_logger()) + else: + validation_set = None + + train_model(model, train_set, validation_set) torch.save(model.state_dict(), model_filename) log_string('saved_model ' + model_filename) @@ -343,10 +382,6 @@ for problem_number in range(1, 24): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - log_string('data_generation {:0.2f} samples / s'.format( - test_set.nb_samples / (time.time() - t)) - ) - nb_test_errors = nb_errors(model, test_set) log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(