+
+def _flatparam(model, whole, already=[], offset=0):
+ for v in model._parameters:
+ p = model._parameters[v]
+ e = p.numel()
+ s = p.size()
+ model._parameters[v] = whole[offset : offset + e].view(s)
+ with torch.no_grad():
+ model._parameters[v].copy_(p)
+ offset += e
+ already.append(model)
+ for m in model.modules():
+ if m not in already:
+ offset = _flatparam(m, whole, already, offset)
+ return offset
+
+