Initial commit.
[pytorch.git] / tinymnist.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import time, os
9 import torch, torchvision
10 from torch import nn
11 from torch.nn import functional as F
12
13 lr, nb_epochs, batch_size = 1e-1, 10, 100
14
15 data_dir = os.environ.get('PYTORCH_DATA_DIR') or './data/'
16
17 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
19 ######################################################################
20
21 train_set = torchvision.datasets.MNIST(root = data_dir, train = True, download = True)
22 train_input = train_set.data.view(-1, 1, 28, 28).float()
23 train_targets = train_set.targets
24
25 test_set = torchvision.datasets.MNIST(root = data_dir, train = False, download = True)
26 test_input = test_set.data.view(-1, 1, 28, 28).float()
27 test_targets = test_set.targets
28
29 ######################################################################
30
31 class SomeLeNet(nn.Module):
32     def __init__(self):
33         super().__init__()
34         self.conv1 = nn.Conv2d(1, 32, kernel_size = 5)
35         self.conv2 = nn.Conv2d(32, 64, kernel_size = 5)
36         self.fc1 = nn.Linear(256, 200)
37         self.fc2 = nn.Linear(200, 10)
38
39     def forward(self, x):
40         x = F.relu(F.max_pool2d(self.conv1(x), kernel_size = 3))
41         x = F.relu(F.max_pool2d(self.conv2(x), kernel_size = 2))
42         x = x.view(x.size(0), -1)
43         x = F.relu(self.fc1(x))
44         x = self.fc2(x)
45         return x
46
47 ######################################################################
48
49 model = SomeLeNet()
50
51 nb_parameters = sum(p.numel() for p in model.parameters())
52
53 print(f'nb_parameters {nb_parameters}')
54
55 optimizer = torch.optim.SGD(model.parameters(), lr = lr)
56 criterion = nn.CrossEntropyLoss()
57
58 model.to(device)
59 criterion.to(device)
60
61 train_input, train_targets = train_input.to(device), train_targets.to(device)
62 test_input, test_targets = test_input.to(device), test_targets.to(device)
63
64 mu, std = train_input.mean(), train_input.std()
65 train_input.sub_(mu).div_(std)
66 test_input.sub_(mu).div_(std)
67
68 start_time = time.perf_counter()
69
70 for k in range(nb_epochs):
71     acc_loss = 0.
72
73     for input, targets in zip(train_input.split(batch_size),
74                               train_targets.split(batch_size)):
75         output = model(input)
76         loss = criterion(output, targets)
77         acc_loss += loss.item()
78
79         optimizer.zero_grad()
80         loss.backward()
81         optimizer.step()
82
83     nb_test_errors = 0
84     for input, targets in zip(test_input.split(batch_size),
85                               test_targets.split(batch_size)):
86         wta = model(input).argmax(1)
87         nb_test_errors += (wta != targets).long().sum()
88     test_error = nb_test_errors / test_input.size(0)
89     duration = time.perf_counter() - start_time
90
91     print(f'loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%')
92
93 ######################################################################