Utilities for partially loading weights from a state_dict #2339

@wyli

Is your feature request related to a problem? Please describe.
A utility function is needed to partially set a model's state_dict from another state_dict.
This is required by #2273 and #2329.

An implementation already exists as part of the checkpoint loader:

k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
if len(self.load_dict) == 1 and k not in checkpoint:
    checkpoint = {k: checkpoint}
# skip items that don't match data shape
for k, obj in self.load_dict.items():
    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        obj = obj.module
    if isinstance(obj, torch.nn.Module):
        d = obj.state_dict()
        checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape}
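
To make the request concrete, here is a minimal sketch of what a standalone version of the filtering step above could look like. The name `merge_matching_state` and its return values are hypothetical, not an existing API. The sketch assumes only that each value exposes a `.shape` attribute (as torch tensors do), so it does not import torch itself.

```python
from typing import Any, Dict, List, Mapping, Tuple


def merge_matching_state(
    dst_state: Mapping[str, Any],
    src_state: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Hypothetical helper: overlay src entries onto dst where key and shape match.

    Values only need a ``.shape`` attribute, so torch tensors qualify.
    Returns (merged, updated_keys, unchanged_keys).
    """
    merged = dict(dst_state)  # shallow copy; the inputs are not mutated
    updated: List[str] = []
    for key, value in src_state.items():
        # same rule as the checkpoint loader: key must exist and shapes must agree
        if key in merged and tuple(value.shape) == tuple(merged[key].shape):
            merged[key] = value
            updated.append(key)
    # keys of dst that kept their original value
    unchanged = [key for key in merged if key not in updated]
    return merged, updated, unchanged
```

With a PyTorch model, this might be called as `merged, updated, unchanged = merge_matching_state(model.state_dict(), checkpoint)` followed by `model.load_state_dict(merged)`. A model wrapped in `nn.DataParallel` or `DistributedDataParallel` would first be unwrapped via `.module`, as in the loader code above. Returning the updated and unchanged keys lets the caller log which weights were actually loaded.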
