Utilities for partially loading weights from a state_dict #2339
Closed
Description
Is your feature request related to a problem? Please describe.
A utility function is needed to set a model's state_dict partially from another one, loading only the entries whose names and shapes match.
This is required by #2273 and #2329.
An implementation is already part of the checkpoint loader:
MONAI/monai/handlers/checkpoint_loader.py
Lines 95 to 106 in a536462
```python
k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
if len(self.load_dict) == 1 and k not in checkpoint:
    checkpoint = {k: checkpoint}

# skip items that don't match data shape
for k, obj in self.load_dict.items():
    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        obj = obj.module
    if isinstance(obj, torch.nn.Module):
        d = obj.state_dict()
        checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape}
```
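The shape filter in the checkpoint loader excerpt could be factored out into a standalone helper along these lines. This is a minimal sketch, not an existing MONAI API: the name `filter_matching_state` and its `(matched, skipped)` return value are hypothetical. It assumes both arguments are state_dict-like mappings whose values expose a `.shape` attribute (as torch tensors do). It also uses distinct loop names to avoid the `k` shadowing in the original comprehension.

```python
from typing import Any, Dict, List, Mapping, Tuple


def filter_matching_state(
    src: Mapping[str, Any], dst: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: keep entries of ``src`` that fit ``dst``.

    An entry is kept only if its key exists in ``dst`` and its value has the
    same shape as the corresponding value in ``dst``. Returns the filtered
    mapping and the list of source keys that were skipped.
    """
    matched: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in src.items():
        # same rule as the loader: key must exist and shapes must agree
        if name in dst and tuple(value.shape) == tuple(dst[name].shape):
            matched[name] = value
        else:
            skipped.append(name)
    return matched, skipped
```

With PyTorch, the result could then be applied with `model.load_state_dict(matched, strict=False)`, where `strict=False` lets parameters absent from `matched` keep their current values. A model wrapped in `nn.DataParallel` or `DistributedDataParallel` would first be unwrapped via `.module`, as the loader excerpt does, so that the keys of `dst = model.state_dict()` carry no `module.` prefix. Returning the skipped keys makes it easy to log what was not loaded.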