Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 26 Mar 2024 17:58:56 +0000 (18:58 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 26 Mar 2024 17:58:56 +0000 (18:58 +0100)
bit_mlp.py

index 8fffe7a..90409f2 100755 (executable)
@@ -9,7 +9,7 @@ import os, sys
 import torch, torchvision
 from torch import nn
 
-lr, nb_epochs, batch_size = 2e-3, 50, 100
+lr, nb_epochs, batch_size = 2e-3, 100, 100
 
 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
 
@@ -57,8 +57,10 @@ class QLinear(nn.Module):
 
 ######################################################################
 
-for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
-    for linear_layer in [nn.Linear, QLinear]:
+errors = {QLinear: [], nn.Linear: []}
+
+for linear_layer in errors.keys():
+    for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
         # The model
 
         model = nn.Sequential(
@@ -114,7 +116,33 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
 
         ######################################################################
 
-        print(
-            f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %"
-        )
-        sys.stdout.flush()
+        errors[linear_layer].append((nb_hidden, test_error))
+
+import matplotlib.pyplot as plt
+
+fig = plt.figure()
+fig.set_figheight(6)
+fig.set_figwidth(8)
+
+ax = fig.add_subplot(1, 1, 1)
+
+ax.set_ylim(0, 1)
+ax.spines.right.set_visible(False)
+ax.spines.top.set_visible(False)
+ax.set_xscale("log")
+ax.set_xlabel("Nb hidden units")
+ax.set_ylabel("Test error (%)")
+
+X = torch.tensor([x[0] for x in errors[nn.Linear]])
+Y = torch.tensor([x[1] for x in errors[nn.Linear]])
+ax.plot(X, Y, color="gray", label="nn.Linear")
+
+X = torch.tensor([x[0] for x in errors[QLinear]])
+Y = torch.tensor([x[1] for x in errors[QLinear]])
+ax.plot(X, Y, color="red", label="QLinear")
+
+ax.legend(frameon=False, loc=1)
+
+filename = f"bit_mlp.pdf"
+print(f"saving {filename}")
+fig.savefig(filename, bbox_inches="tight")