3 import torch, torchvision
6 ######################################################################
8 def _flatparam(model, whole, already = [], offset = 0):
9 for v in model._parameters:
10 p = model._parameters[v]
13 model._parameters[v] = whole[offset:offset+e].view(s)
15 model._parameters[v].copy_(p)
18 for m in model.modules():
20 offset = _flatparam(m, whole, already, offset)
# NOTE(review): the enclosing `def` line for this body is not visible in this
# view. The private helper's name suggests a public `flatparam(model)`; confirm.
# Purpose: flatten all parameters of `model` into one freshly allocated tensor.
# Total element count of the parameters yielded by model.parameters().
# NOTE(review): model.parameters() de-duplicates shared (tied) parameters,
# while _flatparam walks parameters per module. With tied weights, `whole`
# may be too small; confirm no parameters are shared.
24 n = sum(p.numel() for p in model.parameters())
# Allocate n elements with the first parameter's device/dtype (legacy Tensor.new).
# NOTE(review): next() raises StopIteration if the model has no parameters.
# The buffer is uninitialized; this presumably relies on _flatparam writing
# every element.
25 whole = next(model.parameters()).new(n) # Get same device and dtype
# Make the flat buffer a grad-requiring tensor so it can be optimized directly.
26 whole.requires_grad_()
# Replace each parameter with a view into `whole` and copy its values in.
27 _flatparam(model, whole)
# Shadow the bound method on this instance only, so code that later calls
# model.parameters() (e.g. to build an optimizer) sees the single flat tensor.
28 model.parameters = lambda: iter([ whole ])
30 ######################################################################
# Demo network built as an nn.Sequential.
# NOTE(review): the layers between these two lines and the closing
# parentheses are not visible in this view. `nn` is presumably `torch.nn`,
# imported on a line not shown (only torch/torchvision are visible).
32 model = nn.Sequential(
# Last layers visible here: a ReLU followed by Linear(4 -> 2).
37 nn.ReLU(), nn.Linear(4, 2)
41 ######################################################################
# Report each parameter's shape and the size of its underlying storage.
# NOTE(review): Tensor.storage() is deprecated in PyTorch 2.x (TypedStorage
# warning). untyped_storage().size() counts bytes, not elements, if switched.
44 for p in model.parameters():
45 print(p.size(), p.storage().size())
# Same report again. NOTE(review): presumably this runs after the model is
# flattened (that call is on a line not shown). In that case
# model.parameters() yields only the single flat buffer.
50 for p in model.parameters():
51 print(p.size(), p.storage().size())
53 ######################################################################
# Two (100, 2) tensors of uniform random values, used as inputs and targets.
# NOTE(review): `input` shadows the Python builtin of the same name.
57 input = torch.rand(100, 2)
58 targets = torch.rand(100, 2)
# Plain SGD over whatever model.parameters() returns at this point. If the
# model was flattened beforehand, that is a single flat tensor.
59 optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
# NOTE(review): `mse` and `output` are defined on lines not shown (presumably
# a loss module and a forward pass inside a training loop).
64 loss = mse(output, targets)