X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=flatparam.py;h=57a872038408969605aa971af59d41d2ad2257c5;hp=fbede343fd957f054087cb563cec89b2396aea05;hb=HEAD;hpb=a1c89c4da439a4ad48d8f79b6697a2108be4b514 diff --git a/flatparam.py b/flatparam.py index fbede34..0b61cf1 100755 --- a/flatparam.py +++ b/flatparam.py @@ -5,12 +5,13 @@ from torch import nn ###################################################################### -def _flatparam(model, whole, already = [], offset = 0): + +def _flatparam(model, whole, already=[], offset=0): for v in model._parameters: p = model._parameters[v] e = p.numel() s = p.size() - model._parameters[v] = whole[offset:offset+e].view(s) + model._parameters[v] = whole[offset : offset + e].view(s) with torch.no_grad(): model._parameters[v].copy_(p) offset += e @@ -20,43 +21,42 @@ def _flatparam(model, whole, already = [], offset = 0): offset = _flatparam(m, whole, already, offset) return offset + def flatparam(model): n = sum(p.numel() for p in model.parameters()) - whole = next(model.parameters()).new(n) # Get same device and dtype + whole = next(model.parameters()).new(n) # Get same device and dtype whole.requires_grad_() - _flatparam(model, whole, [], 0) - return whole + _flatparam(model, whole) + model.parameters = lambda: iter([whole]) + ###################################################################### model = nn.Sequential( nn.Linear(2, 4), nn.ReLU(), - nn.Sequential( - nn.Linear(4, 4), - nn.ReLU(), nn.Linear(4, 2) - ) + nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)), ) ###################################################################### -print('Before:') +print("Before:") for p in model.parameters(): print(p.size(), p.storage().size()) -whole = flatparam(model) +flatparam(model) -print('After:') +print("After:") for p in model.parameters(): print(p.size(), p.storage().size()) ###################################################################### -print('Check:') +print("Check:") input = torch.rand(100, 2) targets = torch.rand(100, 2) -optimizer = torch.optim.SGD([ whole ], lr = 1e-2) +optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) mse = nn.MSELoss() for e in range(10):