1 #!/usr/bin/env python-for-pytorch
7 from torch import optim
8 from torch import FloatTensor as Tensor
9 from torch.autograd import Variable
11 from torch.nn import functional as fn
12 from torchvision import datasets, transforms, utils
16 ######################################################################
def generate_set(p, n):
    """Build ``n`` labelled samples for SVRT problem ``p``.

    Each label is an independent fair 0/1 draw; the labels are handed to
    ``svrt.generate_vignettes``, whose output gets a singleton channel axis
    inserted at dim 1 and is converted to float.

    Returns:
        ``(Variable(images), Variable(labels))``.  ``images`` is the generator
        output reshaped from ``(size0, size1, size2)`` to
        ``(size0, 1, size1, size2)``; presumably (N, H, W) -> (N, 1, H, W),
        confirm against ``svrt.generate_vignettes``.
    """
    labels = torch.LongTensor(n).bernoulli_(0.5)
    vignettes = svrt.generate_vignettes(p, labels)
    dims = vignettes.size()
    # Insert the channel axis expected by Conv2d(1, ...).
    images = vignettes.view(dims[0], 1, dims[1], dims[2]).float()
    return Variable(images), Variable(labels)
25 ######################################################################
# 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
# conv2 has 20 output channels, so the flattened feature vector holds
# 20 * 5 * 5 = 500 values, which is the input width of fc1.

class Net(nn.Module):
    """Two conv/max-pool stages followed by two linear layers, giving 2 scores."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
        self.fc1 = nn.Linear(500, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        # Stage 1: conv -> 4x4 max-pool -> ReLU.
        h = self.conv1(x)
        h = fn.relu(fn.max_pool2d(h, kernel_size=4, stride=4))
        # Stage 2: conv -> 5x5 max-pool -> ReLU.
        h = self.conv2(h)
        h = fn.relu(fn.max_pool2d(h, kernel_size=5, stride=5))
        # Flatten each sample (500 features for a 1x128x128 input).
        h = h.view(h.size(0), -1)
        h = fn.relu(self.fc1(h))
        scores = self.fc2(h)
        return scores
def train_model(train_input, train_target, nb_epochs=25, bs=100, lr=1e-1):
    """Train a freshly initialised ``Net`` with plain SGD and return it.

    Args:
        train_input: training inputs stacked along dim 0.
        train_target: class indices aligned with ``train_input`` along dim 0.
        nb_epochs: number of passes over the training data.
            NOTE(review): the original epoch count lives on a line that was
            not visible when this was rewritten; 25 is a placeholder, so
            confirm it against the original script.
        bs: mini-batch size.
        lr: SGD learning rate.

    Returns:
        The trained model.  It is moved to the GPU when CUDA is available.
    """
    model, criterion = Net(), nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # The driver moves the data to the GPU under this same condition, so
        # the model must follow or the forward pass would mix devices.
        model.cuda()

    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Bound the loop by the data actually passed in, not by the module-level
    # nb_train_samples.  Shrink the last batch so narrow() never reads past
    # the end when the sample count is not a multiple of bs.
    nb_samples = train_input.size(0)

    for _ in range(nb_epochs):
        for b in range(0, nb_samples, bs):
            size = min(bs, nb_samples - b)
            # Call the module, not .forward(), so registered hooks still run.
            output = model(train_input.narrow(0, b, size))
            loss = criterion(output, train_target.narrow(0, b, size))
            model.zero_grad()
            loss.backward()
            optimizer.step()

    return model
65 ######################################################################
def print_test_error(model, test_input, test_target, bs=100):
    """Print the classification error rate of ``model`` on a test set.

    The predicted class of a sample is the arg-max of the model output along
    dim 1.  Prints a single line ``TEST_ERROR <pct>% (<errors>/<samples>)``.

    Args:
        model: module whose output is indexed as (batch, class scores).
        test_input: test inputs stacked along dim 0.
        test_target: class indices aligned with ``test_input`` along dim 0.
        bs: evaluation batch size.  It only bounds memory use; the reported
            numbers do not depend on it.
    """
    nb_samples = test_input.size(0)
    nb_errors = 0

    for b in range(0, nb_samples, bs):
        # Last batch may be smaller than bs.
        size = min(bs, nb_samples - b)
        output = model(test_input.narrow(0, b, size))
        _, wta = torch.max(output.data, 1)
        # torch.max(x, 1) returns indices of shape (size, 1) in old PyTorch and
        # (size,) in newer releases.  The old wta[i][0] indexing fails on the
        # latter, so flatten and compare the whole batch at once.
        predicted = wta.view(-1)
        target = test_target.narrow(0, b, size).data
        nb_errors += int((predicted != target).sum())

    # 100.0 keeps the rate fractional even under Python 2 integer division.
    rate = 100.0 * nb_errors / nb_samples if nb_samples > 0 else 0.0
    print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
        rate,
        nb_errors,
        nb_samples)
    )
85 ######################################################################
import time

# Per-problem sample counts.  Kept as module-level names because the helper
# functions above may read them as globals.
nb_train_samples = 100000
nb_test_samples = 10000

# Run the full generate / normalise / train / evaluate cycle for SVRT
# problems 1 through 23, reporting wall-clock time for each phase.
for p in range(1, 24):
    print('-- PROBLEM #{:d} --'.format(p))

    t1 = time.time()
    train_input, train_target = generate_set(p, nb_train_samples)
    test_input, test_target = generate_set(p, nb_test_samples)

    if torch.cuda.is_available():
        train_input, train_target, test_input, test_target = tuple(
            tensor.cuda()
            for tensor in (train_input, train_target, test_input, test_target)
        )

    # Standardise both sets in place using statistics of the training set only.
    mu, std = train_input.data.mean(), train_input.data.std()
    train_input.data.sub_(mu).div_(std)
    test_input.data.sub_(mu).div_(std)

    t2 = time.time()
    print('[data generation {:.02f}s]'.format(t2 - t1))
    model = train_model(train_input, train_target)

    t3 = time.time()
    print('[train {:.02f}s]'.format(t3 - t2))
    print_test_error(model, test_input, test_target)

    t4 = time.time()
    print('[test {:.02f}s]'.format(t4 - t3))
117 ######################################################################