X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=cnn-svrt.py;h=1511e823dfe6f71458cc7ba541e51c932efd9794;hb=33ce0facd62e208f04aec201d39e4ed3b989a830;hp=227d9b44620a78827e40176ce46a96c5572522a7;hpb=95146c1d3c5954302284d45dcc3c6da26eaee253;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 227d9b4..1511e82 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -229,7 +229,7 @@ class DeepNet2(nn.Module): self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.fc1 = nn.Linear(2048, 512) self.fc2 = nn.Linear(512, 512) - self.fc3 = nn.Linear(256, 2) + self.fc3 = nn.Linear(512, 2) def forward(self, x): x = self.conv1(x) @@ -250,7 +250,7 @@ class DeepNet2(nn.Module): x = fn.max_pool2d(x, kernel_size=2) x = fn.relu(x) - x = x.view(-1, 1536) + x = x.view(-1, 2048) x = self.fc1(x) x = fn.relu(x)