Update.
[pytorch.git] / flatparam.py
index fbede34..57a8720 100755 (executable)
@@ -24,8 +24,8 @@ def flatparam(model):
     n = sum(p.numel() for p in model.parameters())
     whole = next(model.parameters()).new(n) # Get same device and dtype
     whole.requires_grad_()
-    _flatparam(model, whole, [], 0)
-    return whole
+    _flatparam(model, whole)
+    model.parameters = lambda: iter([ whole ])
 
 ######################################################################
 
@@ -34,7 +34,8 @@ model = nn.Sequential(
     nn.ReLU(),
     nn.Sequential(
         nn.Linear(4, 4),
-        nn.ReLU(), nn.Linear(4, 2)
+        nn.ReLU(),
+        nn.Linear(4, 2)
     )
 )
 
@@ -44,7 +45,7 @@ print('Before:')
 for p in model.parameters():
     print(p.size(), p.storage().size())
 
-whole = flatparam(model)
+flatparam(model)
 
 print('After:')
 for p in model.parameters():
@@ -56,7 +57,7 @@ print('Check:')
 
 input = torch.rand(100, 2)
 targets = torch.rand(100, 2)
-optimizer = torch.optim.SGD([ whole ], lr = 1e-2)
+optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
 mse = nn.MSELoss()
 
 for e in range(10):