X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=338e145c52197e5f4649f031a83a88870473de12;hp=d0704fff48c85fca80c5723c20cb369c0600a013;hb=664f39333e0ae1ed2dccc6ec15e6c458dc8af935;hpb=4c77eebce3c3914a58c548c606d045efdae2284a diff --git a/cnn-svrt.py b/cnn-svrt.py index d0704ff..338e145 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -248,12 +248,13 @@ class DeepNet2(nn.Module): def __init__(self): super(DeepNet2, self).__init__() + self.nb_channels = 512 self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3) - self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2) - self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) - self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) - self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) - self.fc1 = nn.Linear(4096, 512) + self.conv2 = nn.Conv2d( 32, self.nb_channels, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1) + self.conv4 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1) + self.conv5 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1) + self.fc1 = nn.Linear(16 * self.nb_channels, 512) self.fc2 = nn.Linear(512, 512) self.fc3 = nn.Linear(512, 2) @@ -276,7 +277,7 @@ class DeepNet2(nn.Module): x = fn.max_pool2d(x, kernel_size=2) x = fn.relu(x) - x = x.view(-1, 4096) + x = x.view(-1, 16 * self.nb_channels) x = self.fc1(x) x = fn.relu(x) @@ -539,7 +540,10 @@ for problem_number in map(int, args.problems.split(',')): else: validation_set = None - train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done) + train_model(model, model_filename, + train_set, validation_set, + nb_epochs_done = nb_epochs_done) + log_string('saved_model ' + model_filename) nb_train_errors = nb_errors(model, train_set)