X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=7bef242de186a1bbcbad4a84444c5c3311a9445e;hp=6645ac15a55a44d40edbbde18115b8dcb9ca029c;hb=08ef6b7c332153cd72b7a225e27ee7af8882f313;hpb=61e13c9a3cba66d1b6dafaa14efb71e979b8af08 diff --git a/cnn-svrt.py b/cnn-svrt.py index 6645ac1..7bef242 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -171,15 +171,15 @@ for arg in vars(args): for problem_number in range(1, 24): - model_filename = model.name + '_' + \ - str(problem_number) + '_' + \ - str(args.nb_train_batches) + '.param' - model = AfrozeShallowNet() if torch.cuda.is_available(): model.cuda() + model_filename = model.name + '_' + \ + str(problem_number) + '_' + \ + str(args.nb_train_batches) + '.param' + nb_parameters = 0 for p in model.parameters(): nb_parameters += p.numel() log_string('nb_parameters {:d}'.format(nb_parameters))