#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_PRE: source ~/venv/pytorch/bin/activate

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import time, os
import torch, torchvision
from torch import nn
from torch.nn import functional as F

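# Learning rate, number of epochs, and minibatch size.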
lr, nb_epochs, batch_size = 1e-1, 10, 100

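# The data directory can be overridden with the PYTORCH_DATA_DIR
# environment variable.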
data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"

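# Pick the best available backend: CUDA GPU, then Apple Metal (MPS),
# then CPU.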
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

######################################################################

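# Load the full MNIST train and test sets as plain tensors kept in
# memory: .data is N x 28 x 28 uint8, reshaped here to N x 1 x 28 x 28
# floats.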
train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
train_input = train_set.data.view(-1, 1, 28, 28).float()
train_targets = train_set.targets

test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
test_input = test_set.data.view(-1, 1, 28, 28).float()
test_targets = test_set.targets

######################################################################


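# A small LeNet-style convnet. Feature map sizes: 28x28 input, conv
# 5x5 -> 24x24, max-pool 3 -> 8x8, conv 5x5 -> 4x4, max-pool 2 -> 2x2,
# i.e. 64 * 2 * 2 = 256 features into the classifier head.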
class SomeLeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


######################################################################

model = SomeLeNet()

nb_parameters = sum(p.numel() for p in model.parameters())

print(f"device {device} nb_parameters {nb_parameters}")

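# Plain SGD with a constant learning rate. CrossEntropyLoss applies
# log-softmax internally, so the model outputs raw logits.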
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

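# Move the model and the full dataset to the device once; MNIST is
# small enough to keep entirely in memory.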
model.to(device)
criterion.to(device)

train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)

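# Standardize in place with the train statistics; the same mu and std
# are applied to the test set so that no test information leaks in.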
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

start_time = time.perf_counter()

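# One epoch is a full pass over the training set, split into
# minibatches of batch_size samples.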
for k in range(nb_epochs):
    acc_train_loss = 0.0

    for input, targets in zip(
        train_input.split(batch_size), train_targets.split(batch_size)
    ):
        output = model(input)
        loss = criterion(output, targets)
        # Accumulate the sum of per-sample losses, averaged per epoch below.
        acc_train_loss += loss.item() * input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

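    # Evaluate on the test set. No parameters are updated here, so
    # gradients are not needed.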
    acc_test_loss = 0.0
    nb_test_errors = 0
    with torch.no_grad():
        for input, targets in zip(
            test_input.split(batch_size), test_targets.split(batch_size)
        ):
            output = model(input)
            loss = criterion(output, targets)
            acc_test_loss += loss.item() * input.size(0)

            # Winner-take-all prediction: the class with the largest logit.
            wta = output.argmax(1)
            nb_test_errors += (wta != targets).long().sum().item()

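    # Per-epoch summary: elapsed time, average train and test losses,
    # and test error rate.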
    test_error = nb_test_errors / test_input.size(0)
    duration = time.perf_counter() - start_time

    print(
        f"loss {k} {duration:.02f}s acc_train_loss {acc_train_loss/train_input.size(0):.02f} test_loss {acc_test_loss/test_input.size(0):.02f} test_error {test_error*100:.02f}%"
    )

######################################################################