projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Make the name of the saved model more explicit.
[pysvrt.git]
/
cnn-svrt.py
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
084606a
..
a2ab1a3
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@
-107,6
+107,7
@@
class AfrozeShallowNet(nn.Module):
self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 2)
self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 2)
+ self.name = 'shallownet'
def forward(self, x):
x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
def forward(self, x):
x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
@@
-117,6
+118,8
@@
class AfrozeShallowNet(nn.Module):
x = self.fc2(x)
return x
x = self.fc2(x)
return x
+######################################################################
+
def train_model(model, train_set):
batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()
def train_model(model, train_set):
batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()
@@
-178,7
+181,7
@@
for problem_number in range(1, 24):
nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
- model_filename =
'model_' + str(problem_number
) + '.param'
+ model_filename =
model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches
) + '.param'
try:
model.load_state_dict(torch.load(model_filename))
try:
model.load_state_dict(torch.load(model_filename))