X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=227d9b44620a78827e40176ce46a96c5572522a7;hb=95146c1d3c5954302284d45dcc3c6da26eaee253;hp=e5ecf768bfdc060b255c6612649f3f2ede62a51a;hpb=9eec6d457d017e0204cc80c0e1b24f894d064267;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index e5ecf76..227d9b4 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -25,18 +25,22 @@ import time import argparse import math import distutils.util +import re from colorama import Fore, Back, Style # Pytorch import torch +import torchvision from torch import optim +from torch import multiprocessing from torch import FloatTensor as Tensor from torch.autograd import Variable from torch import nn from torch.nn import functional as fn + from torchvision import datasets, transforms, utils # SVRT @@ -56,6 +60,13 @@ parser.add_argument('--nb_train_samples', parser.add_argument('--nb_test_samples', type = int, default = 10000) +parser.add_argument('--nb_validation_samples', + type = int, default = 10000) + +parser.add_argument('--validation_error_threshold', + type = float, default = 0.0, + help = 'Early training termination criterion') + parser.add_argument('--nb_epochs', type = int, default = 50) @@ -65,13 +76,16 @@ parser.add_argument('--batch_size', parser.add_argument('--log_file', type = str, default = 'default.log') +parser.add_argument('--nb_exemplar_vignettes', + type = int, default = 32) + parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') -parser.add_argument('--deep_model', - type = distutils.util.strtobool, default = 'True', - help = 'Use Afroze\'s Alexnet-like deep model') +parser.add_argument('--model', + type = str, default = 'deepnet', + help = 'What model to use') parser.add_argument('--test_loaded_models', type = distutils.util.strtobool, default = 'False', @@ -87,13 +101,15 @@ args = parser.parse_args() log_file = open(args.log_file, 'a') pred_log_t = None +last_tag_t = time.time() print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) # Log and prints the string, with a time stamp. Does not log the # remark + def log_string(s, remark = ''): - global pred_log_t + global pred_log_t, last_tag_t t = time.time() @@ -104,10 +120,14 @@ def log_string(s, remark = ''): pred_log_t = t - log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n') + if t > last_tag_t + 3600: + last_tag_t = t + print(Fore.RED + time.ctime() + Style.RESET_ALL) + + log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) ###################################################################### @@ -126,6 +146,8 @@ def log_string(s, remark = ''): # -- full(84x2) -> 2 1 class AfrozeShallowNet(nn.Module): + name = 'shallownet' + def __init__(self): super(AfrozeShallowNet, self).__init__() self.conv1 = nn.Conv2d(1, 6, kernel_size=21) @@ -133,7 +155,6 @@ class AfrozeShallowNet(nn.Module): self.conv3 = nn.Conv2d(16, 120, kernel_size=18) self.fc1 = nn.Linear(120, 84) self.fc2 = nn.Linear(84, 2) - self.name = 'shallownet' def forward(self, x): x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2)) @@ -149,6 +170,9 @@ class AfrozeShallowNet(nn.Module): # Afroze's DeepNet class AfrozeDeepNet(nn.Module): + + name = 'deepnet' + def __init__(self): super(AfrozeDeepNet, self).__init__() self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) @@ -159,7 +183,6 @@ class AfrozeDeepNet(nn.Module): self.fc1 = nn.Linear(1536, 256) self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, 2) - self.name = 'deepnet' def forward(self, x): x = self.conv1(x) @@ -194,7 +217,69 @@ class AfrozeDeepNet(nn.Module): ###################################################################### -def train_model(model, train_set): +class DeepNet2(nn.Module): + name = 'deepnet2' + + def __init__(self): + super(DeepNet2, self).__init__() + self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) + self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.fc1 = nn.Linear(2048, 512) + self.fc2 = nn.Linear(512, 512) + self.fc3 = nn.Linear(256, 2) + + def forward(self, x): + x = self.conv1(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv2(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv3(x) + x = fn.relu(x) + + x = self.conv4(x) + x = fn.relu(x) + + x = self.conv5(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = x.view(-1, 1536) + + x = self.fc1(x) + x = fn.relu(x) + + x = self.fc2(x) + x = fn.relu(x) + + x = self.fc3(x) + + return x + +###################################################################### + +def nb_errors(model, data_set): + ne = 0 + for b in range(0, data_set.nb_batches): + input, target = data_set.get_batch(b) + output = model.forward(Variable(input)) + wta_prediction = output.data.max(1)[1].view(-1) + + for i in range(0, data_set.batch_size): + if wta_prediction[i] != target[i]: + ne = ne + 1 + + return ne + +###################################################################### + +def train_model(model, train_set, validation_set): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -216,25 +301,24 @@ def train_model(model, train_set): loss.backward() optimizer.step() dt = (time.time() - start_t) / (e + 1) + log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') - return model + if validation_set is not None: + nb_validation_errors = nb_errors(model, validation_set) -###################################################################### - -def nb_errors(model, data_set): - ne = 0 - for b in range(0, data_set.nb_batches): - input, target = data_set.get_batch(b) - output = model.forward(Variable(input)) - wta_prediction = output.data.max(1)[1].view(-1) + log_string('validation_error {:.02f}% {:d} {:d}'.format( + 100 * nb_validation_errors / validation_set.nb_samples, + nb_validation_errors, + validation_set.nb_samples) + ) - for i in range(0, data_set.batch_size): - if wta_prediction[i] != target[i]: - ne = ne + 1 + if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold: + log_string('below validation_error_threshold') + break - return ne + return model ###################################################################### @@ -267,6 +351,21 @@ class vignette_logger(): ) self.last_t = t +def save_examplar_vignettes(data_set, nb, name): + n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb) + + for k in range(0, nb): + b = n[k] // data_set.batch_size + m = n[k] % data_set.batch_size + i, t = data_set.get_batch(b) + i = i[m].float() + i.sub_(i.min()) + i.div_(i.max()) + if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2)) + patchwork[k].copy_(i) + + torchvision.utils.save_image(patchwork, name) + ###################################################################### if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0: @@ -282,14 +381,24 @@ else: log_string('using_uncompressed_vignettes') VignetteSet = svrtset.VignetteSet +######################################## +model_class = None +for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]: + if args.model == m.name: + model_class = m + break +if model_class is None: + print('Unknown model ' + args.model) + raise + +log_string('using model class ' + m.name) +######################################## + for problem_number in map(int, args.problems.split(',')): log_string('############### problem ' + str(problem_number) + ' ###############') - if args.deep_model: - model = AfrozeDeepNet() - else: - model = AfrozeShallowNet() + model = model_class() if torch.cuda.is_available(): model.cuda() @@ -329,7 +438,19 @@ for problem_number in map(int, args.problems.split(',')): train_set.nb_samples / (time.time() - t)) ) - train_model(model, train_set) + if args.nb_exemplar_vignettes > 0: + save_examplar_vignettes(train_set, args.nb_exemplar_vignettes, + 'examplar_{:d}.png'.format(problem_number)) + + if args.validation_error_threshold > 0.0: + validation_set = VignetteSet(problem_number, + args.nb_validation_samples, args.batch_size, + cuda = torch.cuda.is_available(), + logger = vignette_logger()) + else: + validation_set = None + + train_model(model, train_set, validation_set) torch.save(model.state_dict(), model_filename) log_string('saved_model ' + model_filename) @@ -353,10 +474,6 @@ for problem_number in map(int, args.problems.split(',')): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - log_string('data_generation {:0.2f} samples / s'.format( - test_set.nb_samples / (time.time() - t)) - ) - nb_test_errors = nb_errors(model, test_set) log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(