projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Cosmetics.
[pysvrt.git]
/
cnn-svrt.py
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
338e145
..
a6b9cab
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@
-442,7
+442,7
@@
class vignette_logger():
)
self.last_t = t
)
self.last_t = t
-def save_ex
a
mplar_vignettes(data_set, nb, name):
+def save_ex
e
mplar_vignettes(data_set, nb, name):
n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
for k in range(0, nb):
n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
for k in range(0, nb):
@@
-493,7
+493,7
@@
for problem_number in map(int, args.problems.split(',')):
model_filename = model.name + '_pb:' + \
str(problem_number) + '_ns:' + \
model_filename = model.name + '_pb:' + \
str(problem_number) + '_ns:' + \
- int_to_suffix(args.nb_train_samples) + '.
state
'
+ int_to_suffix(args.nb_train_samples) + '.
pth
'
nb_parameters = 0
for p in model.parameters(): nb_parameters += p.numel()
nb_parameters = 0
for p in model.parameters(): nb_parameters += p.numel()
@@
-529,8
+529,8
@@
for problem_number in map(int, args.problems.split(',')):
)
if args.nb_exemplar_vignettes > 0:
)
if args.nb_exemplar_vignettes > 0:
- save_ex
a
mplar_vignettes(train_set, args.nb_exemplar_vignettes,
- 'ex
a
mplar_{:d}.png'.format(problem_number))
+ save_ex
e
mplar_vignettes(train_set, args.nb_exemplar_vignettes,
+ 'ex
e
mplar_{:d}.png'.format(problem_number))
if args.validation_error_threshold > 0.0:
validation_set = VignetteSet(problem_number,
if args.validation_error_threshold > 0.0:
validation_set = VignetteSet(problem_number,