X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=flatparam.py;h=57a872038408969605aa971af59d41d2ad2257c5;hp=3c20153fd6dc9b2dc12b9238f8a2978fe85662a4;hb=HEAD;hpb=7935bad172ac27fa77d28dc8bf7147f5b5aabaaa diff --git a/flatparam.py b/flatparam.py index 3c20153..0b61cf1 100755 --- a/flatparam.py +++ b/flatparam.py @@ -5,39 +5,58 @@ from torch import nn ###################################################################### + +def _flatparam(model, whole, already=[], offset=0): + for v in model._parameters: + p = model._parameters[v] + e = p.numel() + s = p.size() + model._parameters[v] = whole[offset : offset + e].view(s) + with torch.no_grad(): + model._parameters[v].copy_(p) + offset += e + already.append(model) + for m in model.modules(): + if m not in already: + offset = _flatparam(m, whole, already, offset) + return offset + + def flatparam(model): - with torch.no_grad(): - n = sum(p.numel() for p in model.parameters()) - big = next(model.parameters()).new(n) # Get same device and dtype - k = 0 - for p in model.parameters(): - tmp = p.new(0).set_(p) - p.set_(big.storage(), k, p.size()).copy_(tmp) - k += p.numel() + n = sum(p.numel() for p in model.parameters()) + whole = next(model.parameters()).new(n) # Get same device and dtype + whole.requires_grad_() + _flatparam(model, whole) + model.parameters = lambda: iter([whole]) + ###################################################################### model = nn.Sequential( - nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2) + nn.Linear(2, 4), + nn.ReLU(), + nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)), ) -print('Before:') +###################################################################### + +print("Before:") for p in model.parameters(): print(p.size(), p.storage().size()) flatparam(model) -print('After:') +print("After:") for p in model.parameters(): print(p.size(), p.storage().size()) ###################################################################### -print('Check:') +print("Check:") input = torch.rand(100, 2) targets = torch.rand(100, 2) -optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2) +optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) mse = nn.MSELoss() for e in range(10):