From: Francois Fleuret Date: Sat, 17 Jun 2017 16:56:44 +0000 (+0200) Subject: Added Afroze's DeepNet. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=commitdiff_plain;h=b7c9b813a879742e1a2ac359c46c0fb6335455cf Added Afroze's DeepNet. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index c75b336..c7e0585 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -73,6 +73,10 @@ parser.add_argument('--compress_vignettes', action='store_true', default = False, help = 'Use lossless compression to reduce the memory footprint') +parser.add_argument('--deep_model', + action='store_true', default = False, + help = 'Use Afroze\'s Alexnet-like deep model') + parser.add_argument('--test_loaded_models', action='store_true', default = False, help = 'Should we compute the test errors of loaded models') @@ -86,7 +90,9 @@ pred_log_t = None print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) -def log_string(s): +# Log and prints the string, with a time stamp. Does not log the +# remark +def log_string(s, remark = ''): global pred_log_t t = time.time() @@ -98,10 +104,10 @@ def log_string(s): pred_log_t = t - s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + s = Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s log_file.write(s + '\n') log_file.flush() - print(s) + print(s + Fore.CYAN + remark + Style.RESET_ALL) ###################################################################### @@ -140,6 +146,70 @@ class AfrozeShallowNet(nn.Module): ###################################################################### +# Afroze's DeepNet + +# map size nb. maps +# ---------------------- +# input 128x128 1 +# -- conv(21x21 x 32 stride=4) -> 28x28 32 +# -- max(2x2) -> 14x14 6 +# -- conv(7x7 x 96) -> 8x8 16 +# -- max(2x2) -> 4x4 16 +# -- conv(5x5 x 96) -> 26x36 16 +# -- conv(3x3 x 128) -> 36x36 16 +# -- conv(3x3 x 128) -> 36x36 16 + +# -- conv(5x5 x 120) -> 1x1 120 +# -- reshape -> 120 1 +# -- full(3x84) -> 84 1 +# -- full(84x2) -> 2 1 + +class AfrozeDeepNet(nn.Module): + def __init__(self): + super(AfrozeDeepNet, self).__init__() + self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) + self.conv2 = nn.Conv2d( 32, 96, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(128, 96, kernel_size=3, padding=1) + self.fc1 = nn.Linear(1536, 256) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 2) + self.name = 'deepnet' + + def forward(self, x): + x = self.conv1(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv2(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv3(x) + x = fn.relu(x) + + x = self.conv4(x) + x = fn.relu(x) + + x = self.conv5(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = x.view(-1, 1536) + + x = self.fc1(x) + x = fn.relu(x) + + x = self.fc2(x) + x = fn.relu(x) + + x = self.fc3(x) + + return x + +###################################################################### + def train_model(model, train_set): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -161,9 +231,9 @@ def train_model(model, train_set): model.zero_grad() loss.backward() optimizer.step() - log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss)) dt = (time.time() - start_t) / (e + 1) - print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL) + log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), + ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') return model @@ -193,7 +263,10 @@ for problem_number in range(1, 24): log_string('**** problem ' + str(problem_number) + ' ****') - model = AfrozeShallowNet() + if args.deep_model: + model = AfrozeDeepNet() + else: + model = AfrozeShallowNet() if torch.cuda.is_available(): model.cuda()