pytorch.git / commitdiff (parent: 53d7745)
Update.

author    François Fleuret <francois@fleuret.org>
          Tue, 26 Mar 2024 17:45:58 +0000 (18:45 +0100)
committer François Fleuret <francois@fleuret.org>
          Tue, 26 Mar 2024 17:45:58 +0000 (18:45 +0100)
bit_mlp.py
diff --git a/bit_mlp.py b/bit_mlp.py
index 85262b7..8fffe7a 100755
--- a/bit_mlp.py
+++ b/bit_mlp.py
@@ -59,6 +59,8 @@ class QLinear(nn.Module):
 for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
     for linear_layer in [nn.Linear, QLinear]:
 
+        # The model
+
         model = nn.Sequential(
             nn.Flatten(),
             linear_layer(784, nb_hidden),
@@ -72,10 +74,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
 
         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-        #
         #####################################################################
+        #
         for k in range(nb_epochs):
-            ############################################
             # Train
 
             model.train()
@@ -93,7 +94,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
                 loss.backward()
                 optimizer.step()
 
-            ############################################
             # Test
 
             model.eval()
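The first hunk header references class QLinear, whose body is not shown in this commit. Purely for orientation, a quantized linear layer of this kind is commonly built as a ternary-weight layer trained with a straight-through estimator; the sketch below illustrates that general technique and is an assumption, not the repository's actual implementation.

import torch
from torch import nn
from torch.nn import functional as F

class QLinear(nn.Module):
    # Hypothetical sketch: a linear layer whose weights are quantized to
    # {-1, 0, +1} times a shared scale in the forward pass. Not the
    # repository's implementation, only an illustration of the technique.

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) / in_features**0.5
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        w = self.weight
        # Normalize by the mean absolute weight, then round to {-1, 0, +1}
        scale = w.abs().mean().clamp(min=1e-8)
        w_q = (w / scale).round().clamp(-1, 1) * scale
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass
        w_q = w + (w_q - w).detach()
        return F.linear(x, w_q, self.bias)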
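For context, the hunks above touch a sweep that builds a small MLP for every (nb_hidden, linear_layer) pair and trains it with Adam. A minimal self-contained reconstruction, using the QLinear sketch above, synthetic stand-in data, and hypothetical hyperparameters (the 784-dimensional input suggests 28x28 MNIST-style images, but the data pipeline, activation, and settings below are assumptions not visible in the diff), could look like:

# Synthetic stand-in data; assumption, not from the diff
train_x, train_y = torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,))
test_x, test_y = torch.randn(200, 1, 28, 28), torch.randint(0, 10, (200,))

nb_epochs, lr, batch_size = 5, 1e-3, 100  # hypothetical values

for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
    for linear_layer in [nn.Linear, QLinear]:

        # The model
        model = nn.Sequential(
            nn.Flatten(),
            linear_layer(784, nb_hidden),
            nn.ReLU(),  # assumption: the activation is not visible in the hunk
            linear_layer(nb_hidden, 10),
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for k in range(nb_epochs):
            # Train
            model.train()
            for b in range(0, train_x.size(0), batch_size):
                loss = criterion(model(train_x[b:b+batch_size]),
                                 train_y[b:b+batch_size])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Test
            model.eval()
            with torch.no_grad():
                nb_errors = (model(test_x).argmax(1) != test_y).long().sum()

        print(f"{linear_layer.__name__} nb_hidden={nb_hidden} "
              f"test_error={nb_errors.item() / test_x.size(0):.2%}")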