Update.
[pytorch.git] / bit_mlp.py
index 85262b7..6f7f92e 100755 (executable)
@@ -9,7 +9,7 @@ import os, sys
 import torch, torchvision
 from torch import nn
 
-lr, nb_epochs, batch_size = 2e-3, 50, 100
+lr, nb_epochs, batch_size = 2e-3, 100, 100
 
 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
 
@@ -57,8 +57,12 @@ class QLinear(nn.Module):
 
 ######################################################################
 
-for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
-    for linear_layer in [nn.Linear, QLinear]:
+errors = {QLinear: [], nn.Linear: []}
+
+for linear_layer in errors.keys():
+    for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
+        # The model
+
         model = nn.Sequential(
             nn.Flatten(),
             linear_layer(784, nb_hidden),
@@ -72,10 +76,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
 
         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-        ######################################################################
+        #
 
         for k in range(nb_epochs):
-            ############################################
             # Train
 
             model.train()
@@ -93,7 +96,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
                 loss.backward()
                 optimizer.step()
 
-            ############################################
             # Test
 
             model.eval()
@@ -114,7 +116,40 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
 
         ######################################################################
 
-        print(
-            f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %"
+        errors[linear_layer].append(
+            (nb_hidden, test_error * 100, acc_train_loss / train_input.size(0))
         )
-        sys.stdout.flush()
+
+import matplotlib.pyplot as plt
+
+
+def save_fig(filename, ymax, ylabel, index):
+    fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(8)
+
+    ax = fig.add_subplot(1, 1, 1)
+
+    ax.set_ylim(0, ymax)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
+    ax.set_xscale("log")
+    ax.set_xlabel("Nb hidden units")
+    ax.set_ylabel(ylabel)
+
+    X = torch.tensor([x[0] for x in errors[nn.Linear]])
+    Y = torch.tensor([x[index] for x in errors[nn.Linear]])
+    ax.plot(X, Y, color="gray", label="nn.Linear")
+
+    X = torch.tensor([x[0] for x in errors[QLinear]])
+    Y = torch.tensor([x[index] for x in errors[QLinear]])
+    ax.plot(X, Y, color="red", label="QLinear")
+
+    ax.legend(frameon=False, loc=1)
+
+    print(f"saving {filename}")
+    fig.savefig(filename, bbox_inches="tight")
+
+
+save_fig("bit_mlp_err.pdf", ymax=15, ylabel="Test error (%)", index=1)
+save_fig("bit_mlp_loss.pdf", ymax=1.25, ylabel="Train loss", index=2)