projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
1660da6
)
Update.
master
author
François Fleuret
<francois@fleuret.org>
Tue, 26 Mar 2024 18:55:44 +0000
(19:55 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 26 Mar 2024 18:55:44 +0000
(19:55 +0100)
bit_mlp.py
patch
|
blob
|
history
diff --git
a/bit_mlp.py
b/bit_mlp.py
index
90409f2
..
6f7f92e
100755
(executable)
--- a/
bit_mlp.py
+++ b/
bit_mlp.py
@@
-116,33
+116,40
@@
for linear_layer in errors.keys():
######################################################################
######################################################################
- errors[linear_layer].append((nb_hidden, test_error))
+ errors[linear_layer].append(
+ (nb_hidden, test_error * 100, acc_train_loss / train_input.size(0))
+ )
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
-fig = plt.figure()
-fig.set_figheight(6)
-fig.set_figwidth(8)
-ax = fig.add_subplot(1, 1, 1)
+def save_fig(filename, ymax, ylabel, index):
+ fig = plt.figure()
+ fig.set_figheight(6)
+ fig.set_figwidth(8)
-ax.set_ylim(0, 1)
-ax.spines.right.set_visible(False)
-ax.spines.top.set_visible(False)
-ax.set_xscale("log")
-ax.set_xlabel("Nb hidden units")
-ax.set_ylabel("Test error (%)")
+ ax = fig.add_subplot(1, 1, 1)
-X = torch.tensor([x[0] for x in errors[nn.Linear]])
-Y = torch.tensor([x[1] for x in errors[nn.Linear]])
-ax.plot(X, Y, color="gray", label="nn.Linear")
+ ax.set_ylim(0, ymax)
+ ax.spines.right.set_visible(False)
+ ax.spines.top.set_visible(False)
+ ax.set_xscale("log")
+ ax.set_xlabel("Nb hidden units")
+ ax.set_ylabel(ylabel)
-
X = torch.tensor([x[0] for x in errors[Q
Linear]])
-
Y = torch.tensor([x[1] for x in errors[Q
Linear]])
-
ax.plot(X, Y, color="red", label="Q
Linear")
+
X = torch.tensor([x[0] for x in errors[nn.
Linear]])
+
Y = torch.tensor([x[index] for x in errors[nn.
Linear]])
+
ax.plot(X, Y, color="gray", label="nn.
Linear")
-ax.legend(frameon=False, loc=1)
+ X = torch.tensor([x[0] for x in errors[QLinear]])
+ Y = torch.tensor([x[index] for x in errors[QLinear]])
+ ax.plot(X, Y, color="red", label="QLinear")
-filename = f"bit_mlp.pdf"
-print(f"saving {filename}")
-fig.savefig(filename, bbox_inches="tight")
+ ax.legend(frameon=False, loc=1)
+
+ print(f"saving {filename}")
+ fig.savefig(filename, bbox_inches="tight")
+
+
+save_fig("bit_mlp_err.pdf", ymax=15, ylabel="Test error (%)", index=1)
+save_fig("bit_mlp_loss.pdf", ymax=1.25, ylabel="Train loss", index=2)