X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=0d4b31364832ffed5005a29fb1b2b627f74cfd17;hb=91b12f8980a69a99fd6bbdc9b6f6a422dd8cd15a;hp=c7e0585abd009c5efb26913ee214ac74ec22eb41;hpb=b7c9b813a879742e1a2ac359c46c0fb6335455cf;p=pysvrt.git
diff --git a/cnn-svrt.py b/cnn-svrt.py
index c7e0585..0d4b313 100755
--- a/cnn-svrt.py
+++ b/cnn-svrt.py
@@ -19,11 +19,12 @@
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
-# along with selector. If not, see <http://www.gnu.org/licenses/>.
+# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import time
import argparse
import math
+import distutils.util
from colorama import Fore, Back, Style
@@ -40,45 +41,40 @@ from torchvision import datasets, transforms, utils
# SVRT
-from vignette_set import VignetteSet, CompressedVignetteSet
+import vignette_set
######################################################################
parser = argparse.ArgumentParser(
- description = 'Simple convnet test on the SVRT.',
+ description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
-parser.add_argument('--nb_train_batches',
- type = int, default = 1000,
- help = 'How many samples for train')
+parser.add_argument('--nb_train_samples',
+ type = int, default = 100000)
-parser.add_argument('--nb_test_batches',
- type = int, default = 100,
- help = 'How many samples for test')
+parser.add_argument('--nb_test_samples',
+ type = int, default = 10000)
parser.add_argument('--nb_epochs',
- type = int, default = 50,
- help = 'How many training epochs')
+ type = int, default = 50)
parser.add_argument('--batch_size',
- type = int, default = 100,
- help = 'Mini-batch size')
+ type = int, default = 100)
parser.add_argument('--log_file',
- type = str, default = 'default.log',
- help = 'Log file name')
+ type = str, default = 'default.log')
parser.add_argument('--compress_vignettes',
- action='store_true', default = False,
+ type = distutils.util.strtobool, default = 'True',
help = 'Use lossless compression to reduce the memory footprint')
parser.add_argument('--deep_model',
- action='store_true', default = False,
+ type = distutils.util.strtobool, default = 'True',
help = 'Use Afroze\'s Alexnet-like deep model')
parser.add_argument('--test_loaded_models',
- action='store_true', default = False,
+ type = distutils.util.strtobool, default = 'False',
help = 'Should we compute the test errors of loaded models')
args = parser.parse_args()
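
With the switch from action='store_true' to type = distutils.util.strtobool, the compression and deep-model options default to on but can still be turned off explicitly on the command line. A minimal sketch of how a strtobool-typed option behaves (the option name is from the patch, the values are illustrative, e.g. ./cnn-svrt.py --deep_model False --compress_vignettes no):

    import argparse
    import distutils.util

    p = argparse.ArgumentParser()
    p.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True')

    # strtobool accepts 'true'/'false', 'yes'/'no', '1'/'0' (case-insensitive) and returns 1 or 0;
    # argparse also runs the type converter on the string default.
    print(p.parse_args([]).compress_vignettes)                                   # -> 1
    print(p.parse_args(['--compress_vignettes', 'False']).compress_vignettes)    # -> 0
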
@@ -104,10 +100,10 @@ def log_string(s, remark = ''):
pred_log_t = t
- s = Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
- log_file.write(s + '\n')
+ log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
log_file.flush()
- print(s + Fore.CYAN + remark + Style.RESET_ALL)
+
+ print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
######################################################################
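
The reworked log_string() writes an uncoloured line to the log file and keeps the colorama escape sequences only for the console output, so the persisted log stays plain, grep-able text. A standalone sketch of the same idea (the file name and elapsed-time string are illustrative, not taken from the patch):

    import time
    from colorama import Fore, Style

    log_file = open('example.log', 'a')

    def log_plain_and_coloured(s, elapsed = '+0s'):
        stamp = '[' + time.ctime() + '] '
        log_file.write(stamp + elapsed + ' ' + s + '\n')   # plain text only in the file
        log_file.flush()
        # colours only on the console
        print(Fore.BLUE + stamp + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s)
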
@@ -148,22 +144,6 @@ class AfrozeShallowNet(nn.Module):
# Afroze's DeepNet
-# map size nb. maps
-# ----------------------
-# input 128x128 1
-# -- conv(21x21 x 32 stride=4) -> 28x28 32
-# -- max(2x2) -> 14x14 6
-# -- conv(7x7 x 96) -> 8x8 16
-# -- max(2x2) -> 4x4 16
-# -- conv(5x5 x 96) -> 26x36 16
-# -- conv(3x3 x 128) -> 36x36 16
-# -- conv(3x3 x 128) -> 36x36 16
-
-# -- conv(5x5 x 120) -> 1x1 120
-# -- reshape -> 120 1
-# -- full(3x84) -> 84 1
-# -- full(84x2) -> 2 1
-
class AfrozeDeepNet(nn.Module):
def __init__(self):
super(AfrozeDeepNet, self).__init__()
@@ -259,26 +239,49 @@ for arg in vars(args):
######################################################################
+def int_to_suffix(n):
+ if n >= 1000000 and n%1000000 == 0:
+ return str(n//1000000) + 'M'
+ elif n >= 1000 and n%1000 == 0:
+ return str(n//1000) + 'K'
+ else:
+ return str(n)
+
+######################################################################
+
+if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
+    raise ValueError('The number of samples must be a multiple of the batch size.')
+
+if args.compress_vignettes:
+ log_string('using_compressed_vignettes')
+ VignetteSet = vignette_set.CompressedVignetteSet
+else:
+ log_string('using_uncompressed_vignettes')
+ VignetteSet = vignette_set.VignetteSet
+
for problem_number in range(1, 24):
- log_string('**** problem ' + str(problem_number) + ' ****')
+ log_string('############### problem ' + str(problem_number) + ' ###############')
if args.deep_model:
model = AfrozeDeepNet()
else:
model = AfrozeShallowNet()
- if torch.cuda.is_available():
- model.cuda()
+ if torch.cuda.is_available(): model.cuda()
- model_filename = model.name + '_' + \
- str(problem_number) + '_' + \
- str(args.nb_train_batches) + '.param'
+ model_filename = model.name + '_pb:' + \
+ str(problem_number) + '_ns:' + \
+ int_to_suffix(args.nb_train_samples) + '.param'
nb_parameters = 0
for p in model.parameters(): nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
+ ##################################################
+ # Tries to load the model
+
need_to_train = False
try:
model.load_state_dict(torch.load(model_filename))
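
The checkpoint file name now encodes the problem number and the training-set size via int_to_suffix(), so models trained with different sample counts no longer overwrite one another. A quick sketch of the resulting names ('deepnet' is an illustrative value for model.name, which is set elsewhere in the script):

    def int_to_suffix(n):
        if n >= 1000000 and n % 1000000 == 0:
            return str(n // 1000000) + 'M'
        elif n >= 1000 and n % 1000 == 0:
            return str(n // 1000) + 'K'
        else:
            return str(n)

    assert int_to_suffix(100000) == '100K'
    assert int_to_suffix(1000000) == '1M'
    assert int_to_suffix(2500) == '2500'

    model_name, problem_number, nb_train_samples = 'deepnet', 14, 100000
    print(model_name + '_pb:' + str(problem_number) + '_ns:' + int_to_suffix(nb_train_samples) + '.param')
    # -> deepnet_pb:14_ns:100K.param
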
@@ -286,22 +289,22 @@ for problem_number in range(1, 24):
except:
need_to_train = True
+ ##################################################
+ # Train if necessary
+
if need_to_train:
log_string('training_model ' + model_filename)
t = time.time()
- if args.compress_vignettes:
- train_set = CompressedVignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
- else:
- train_set = VignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ train_set = VignetteSet(problem_number,
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ train_set.nb_samples / (time.time() - t))
+ )
train_model(model, train_set)
torch.save(model.state_dict(), model_filename)
@@ -316,20 +319,20 @@ for problem_number in range(1, 24):
train_set.nb_samples)
)
+ ##################################################
+ # Test if necessary
+
if need_to_train or args.test_loaded_models:
t = time.time()
- if args.compress_vignettes:
- test_set = CompressedVignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
- else:
- test_set = VignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ test_set = VignetteSet(problem_number,
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ test_set.nb_samples / (time.time() - t))
+ )
nb_test_errors = nb_errors(model, test_set)
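
nb_errors() itself is not shown in this hunk; a hedged sketch of how its count would typically be turned into a percentage for log_string() (the helper and the figures are illustrative, not from the patch):

    def error_rate_percent(nb_test_errors, nb_samples):
        # fraction of misclassified test vignettes, as a percentage
        return 100.0 * nb_test_errors / nb_samples

    print('{:0.2f}%'.format(error_rate_percent(412, 10000)))  # -> 4.12%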