X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=ade87ceea78dba435a3f9982e2e55f1bb719357e;hb=1ae0133746fd78a916ac540475c64a0e5fccd3e4;hp=227d9b44620a78827e40176ce46a96c5572522a7;hpb=95146c1d3c5954302284d45dcc3c6da26eaee253;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 227d9b4..ade87ce 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -24,8 +24,10 @@ import time import argparse import math + import distutils.util import re +import signal from colorama import Fore, Back, Style @@ -127,7 +129,24 @@ def log_string(s, remark = ''): log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n') log_file.flush() - print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL) + print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed \ + + Style.RESET_ALL + + ' ' \ + + s + Fore.CYAN + remark \ + + Style.RESET_ALL) + +###################################################################### + +def handler_sigint(signum, frame): + log_string('got sigint') + exit(0) + +def handler_sigterm(signum, frame): + log_string('got sigterm') + exit(0) + +signal.signal(signal.SIGINT, handler_sigint) +signal.signal(signal.SIGTERM, handler_sigterm) ###################################################################### @@ -223,12 +242,61 @@ class DeepNet2(nn.Module): def __init__(self): super(DeepNet2, self).__init__() self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) + self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.fc1 = nn.Linear(4096, 512) + self.fc2 = nn.Linear(512, 512) + self.fc3 = nn.Linear(512, 2) + + def forward(self, x): + x = self.conv1(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv2(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = self.conv3(x) + x = fn.relu(x) + + x = self.conv4(x) + x = fn.relu(x) + + x = self.conv5(x) + x = fn.max_pool2d(x, kernel_size=2) + x = fn.relu(x) + + x = x.view(-1, 4096) + + x = self.fc1(x) + x = fn.relu(x) + + x = self.fc2(x) + x = fn.relu(x) + + x = self.fc3(x) + + return x + +###################################################################### + +class DeepNet3(nn.Module): + name = 'deepnet3' + + def __init__(self): + super(DeepNet3, self).__init__() + self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2) self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1) - self.fc1 = nn.Linear(2048, 512) - self.fc2 = nn.Linear(512, 512) + self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1) + self.fc1 = nn.Linear(2048, 256) + self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, 2) def forward(self, x): @@ -250,7 +318,13 @@ class DeepNet2(nn.Module): x = fn.max_pool2d(x, kernel_size=2) x = fn.relu(x) - x = x.view(-1, 1536) + x = self.conv6(x) + x = fn.relu(x) + + x = self.conv7(x) + x = fn.relu(x) + + x = x.view(-1, 2048) x = self.fc1(x) x = fn.relu(x) @@ -279,7 +353,7 @@ def nb_errors(model, data_set): ###################################################################### -def train_model(model, train_set, validation_set): +def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0): batch_size = args.batch_size criterion = nn.CrossEntropyLoss() @@ -290,7 +364,7 @@ def train_model(model, train_set, validation_set): start_t = time.time() - for e in range(0, args.nb_epochs): + for e in range(nb_epochs_done, args.nb_epochs): acc_loss = 0.0 for b in range(0, train_set.nb_batches): input, target = train_set.get_batch(b) @@ -305,6 +379,8 @@ def train_model(model, train_set, validation_set): log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss), ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']') + torch.save([ model.state_dict(), e + 1 ], model_filename) + if validation_set is not None: nb_validation_errors = nb_errors(model, validation_set) @@ -383,7 +459,7 @@ else: ######################################## model_class = None -for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]: +for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]: if args.model == m.name: model_class = m break @@ -404,7 +480,7 @@ for problem_number in map(int, args.problems.split(',')): model_filename = model.name + '_pb:' + \ str(problem_number) + '_ns:' + \ - int_to_suffix(args.nb_train_samples) + '.param' + int_to_suffix(args.nb_train_samples) + '.state' nb_parameters = 0 for p in model.parameters(): nb_parameters += p.numel() @@ -413,17 +489,18 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Tries to load the model - need_to_train = False try: - model.load_state_dict(torch.load(model_filename)) + model_state_dict, nb_epochs_done = torch.load(model_filename) + model.load_state_dict(model_state_dict) log_string('loaded_model ' + model_filename) except: - need_to_train = True + nb_epochs_done = 0 + ################################################## # Train if necessary - if need_to_train: + if nb_epochs_done < args.nb_epochs: log_string('training_model ' + model_filename) @@ -450,8 +527,7 @@ for problem_number in map(int, args.problems.split(',')): else: validation_set = None - train_model(model, train_set, validation_set) - torch.save(model.state_dict(), model_filename) + train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done) log_string('saved_model ' + model_filename) nb_train_errors = nb_errors(model, train_set) @@ -466,7 +542,7 @@ for problem_number in map(int, args.problems.split(',')): ################################################## # Test if necessary - if need_to_train or args.test_loaded_models: + if nb_epochs_done < args.nb_epochs or args.test_loaded_models: t = time.time()