+
+import matplotlib.pyplot as plt
+
+
+def save_fig(filename, ymax, ylabel, index):
+ fig = plt.figure()
+ fig.set_figheight(6)
+ fig.set_figwidth(8)
+
+ ax = fig.add_subplot(1, 1, 1)
+
+ ax.set_ylim(0, ymax)
+ ax.spines.right.set_visible(False)
+ ax.spines.top.set_visible(False)
+ ax.set_xscale("log")
+ ax.set_xlabel("Nb hidden units")
+ ax.set_ylabel(ylabel)
+
+ X = torch.tensor([x[0] for x in errors[nn.Linear]])
+ Y = torch.tensor([x[index] for x in errors[nn.Linear]])
+ ax.plot(X, Y, color="gray", label="nn.Linear")
+
+ X = torch.tensor([x[0] for x in errors[QLinear]])
+ Y = torch.tensor([x[index] for x in errors[QLinear]])
+ ax.plot(X, Y, color="red", label="QLinear")
+
+ ax.legend(frameon=False, loc=1)
+
+ print(f"saving {filename}")
+ fig.savefig(filename, bbox_inches="tight")
+
+
+save_fig("bit_mlp_err.pdf", ymax=15, ylabel="Test error (%)", index=1)
+save_fig("bit_mlp_loss.pdf", ymax=1.25, ylabel="Train loss", index=2)