X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=bit_mlp.py;fp=bit_mlp.py;h=8fffe7a020f9b1118b78c5b0784a12e0a35e8883;hp=85262b72f28f5c6b0cacd6ea8bb86d6d6967cd55;hb=7195d3207fccf4ea38238bdde50399ea344a695f;hpb=53d7745e661073bad93752ea41b0320312250954 diff --git a/bit_mlp.py b/bit_mlp.py index 85262b7..8fffe7a 100755 --- a/bit_mlp.py +++ b/bit_mlp.py @@ -59,6 +59,8 @@ class QLinear(nn.Module): for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: for linear_layer in [nn.Linear, QLinear]: + # The model + model = nn.Sequential( nn.Flatten(), linear_layer(784, nb_hidden), @@ -72,10 +74,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: optimizer = torch.optim.Adam(model.parameters(), lr=lr) - ###################################################################### + # for k in range(nb_epochs): - ############################################ # Train model.train() @@ -93,7 +94,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: loss.backward() optimizer.step() - ############################################ # Test model.eval()