90409f2ebdcb7247fa2c1211d78ba2f53c97dda5
[pytorch.git] / bit_mlp.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import os, sys
9 import torch, torchvision
10 from torch import nn
11
# Optimization hyper-parameters
lr = 2e-3
nb_epochs = 100
batch_size = 100

# PYTORCH_DATA_DIR selects where MNIST is cached; when it is unset or empty
# the data goes to ./data/
data_dir = os.environ.get("PYTORCH_DATA_DIR", "")
if not data_dir:
    data_dir = "./data/"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

######################################################################

# Images become float tensors of shape (N, 1, 28, 28); everything lives on `device`
train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
train_input = train_set.data.view(-1, 1, 28, 28).float().to(device)
train_targets = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
test_input = test_set.data.view(-1, 1, 28, 28).float().to(device)
test_targets = test_set.targets.to(device)

# Standardize both splits in place using statistics of the training set only
mu = train_input.mean()
std = train_input.std()
for split_input in (train_input, test_input):
    split_input.sub_(mu).div_(std)
35
36 ######################################################################
37
38
class QLinear(nn.Module):
    """Linear layer whose weight and bias are ternarized to {-1, 0, 1}.

    Full-precision latent parameters ``w`` (shape ``(dim_out, dim_in)``) and
    ``b`` (shape ``(dim_out,)``) are stored and trained; every forward pass
    quantizes them with :meth:`quantize` before applying the affine map.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim_out, dim_in))
        self.b = nn.Parameter(torch.randn(dim_out) * 1e-1)

    def quantize(self, z):
        """Ternarize ``z`` to -1, 0 or 1 relative to its mean absolute value.

        Each entry is divided by ``mean(|z|) + epsilon``. Entries whose scaled
        value is >= 0.5 map to 1, entries <= -0.5 map to -1, all others to 0.
        In training mode the result carries a straight-through gradient: the
        backward pass treats the quantization as the identity w.r.t. ``z``.

        Args:
            z: tensor to quantize (the latent weight or bias).

        Returns:
            Tensor with the shape and dtype of ``z`` whose values are exactly
            -1, 0 or 1 (requires grad through ``z`` only in training mode).
        """
        epsilon = 1e-3  # keeps the division finite when z is all zeros
        zr = z / (z.abs().mean() + epsilon)
        zq = (zr >= 0.5).to(z.dtype) - (zr <= -0.5).to(z.dtype)
        if self.training:
            # z - z.detach() is exactly 0 in value but has gradient 1 w.r.t. z.
            # Parenthesizing keeps the forward value exactly zq; evaluating
            # (zq + z) - z instead can round away from {-1, 0, 1}.
            return zq + (z - z.detach())
        else:
            return zq

    def forward(self, x):
        """Apply ``x @ quantize(w).T + quantize(b)``; x's last dim must be dim_in."""
        return x @ self.quantize(self.w).t() + self.quantize(self.b)
56
57
58 ######################################################################
59
# For each layer type and each hidden width, train a one-hidden-layer MLP and
# record (nb_hidden, test error of the last epoch as a fraction in [0, 1]).
errors = {QLinear: [], nn.Linear: []}

for linear_layer in errors:
    for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
        # The model

        model = nn.Sequential(
            nn.Flatten(),
            linear_layer(784, nb_hidden),
            nn.ReLU(),
            linear_layer(nb_hidden, 10),
        ).to(device)

        nb_parameters = sum(p.numel() for p in model.parameters())

        print(f"nb_parameters {nb_parameters}")

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        #

        for k in range(nb_epochs):
            # Train

            model.train()

            acc_train_loss = 0.0

            for batch_input, batch_targets in zip(
                train_input.split(batch_size), train_targets.split(batch_size)
            ):
                output = model(batch_input)
                loss = torch.nn.functional.cross_entropy(output, batch_targets)
                # Weight by batch size so the epoch average is exact even
                # when the last batch is smaller
                acc_train_loss += loss.item() * batch_input.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Test

            model.eval()

            nb_test_errors = 0
            # Evaluation needs no autograd graph
            with torch.no_grad():
                for batch_input, batch_targets in zip(
                    test_input.split(batch_size), test_targets.split(batch_size)
                ):
                    wta = model(batch_input).argmax(1)
                    nb_test_errors += (wta != batch_targets).sum().item()
            # Plain Python float, so results do not hold device tensors
            test_error = nb_test_errors / test_input.size(0)

            if (k + 1) % 10 == 0:
                print(
                    f"loss {k+1} {acc_train_loss/train_input.size(0)} {test_error*100:.02f}%"
                )
                sys.stdout.flush()

        ######################################################################

        errors[linear_layer].append((nb_hidden, test_error))
120
import matplotlib.pyplot as plt

# Plot final test error versus hidden-layer width for both layer types
fig = plt.figure()
fig.set_figheight(6)
fig.set_figwidth(8)

ax = fig.add_subplot(1, 1, 1)

# Errors are recorded as fractions (errors / test-set size); they are plotted
# in percent so the values match the "%" in the y-axis label. The (0, 100)
# range is the percent equivalent of a (0, 1) range on fractions.
ax.set_ylim(0, 100)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.set_xscale("log")
ax.set_xlabel("Nb hidden units")
ax.set_ylabel("Test error (%)")

for layer_type, color, label in [
    (nn.Linear, "gray", "nn.Linear"),
    (QLinear, "red", "QLinear"),
]:
    X = [x[0] for x in errors[layer_type]]
    # float() accepts Python numbers as well as 0-d tensors (CPU or CUDA)
    Y = [float(x[1]) * 100 for x in errors[layer_type]]
    ax.plot(X, Y, color=color, label=label)

ax.legend(frameon=False, loc=1)

filename = "bit_mlp.pdf"
print(f"saving {filename}")
fig.savefig(filename, bbox_inches="tight")