+def _flatparam(model, whole, already = [], offset = 0):
+ for v in model._parameters:
+ p = model._parameters[v]
+ e = p.numel()
+ s = p.size()
+ model._parameters[v] = whole[offset:offset+e].view(s)
+ with torch.no_grad():
+ model._parameters[v].copy_(p)
+ offset += e
+ already.append(model)
+ for m in model.modules():
+ if m not in already:
+ offset = _flatparam(m, whole, already, offset)
+ return offset
+