Description
Is there an existing issue for this?
- I have searched the existing issues
Bug description
Hi DLC team, I’m running this evaluate function:
deeplabcut.evaluate_network(config_path, plotting=True)
But I ran into an error while loading the state_dict, at line 67 of <environment folder>/lib/python3.10/site-packages/deeplabcut/pose_estimation_pytorch/runners/base.py:
@staticmethod
def load_snapshot(
    snapshot_path: str | Path,
    device: str,
    model: ModelType,
    optimizer: torch.optim.Optimizer | None = None,
) -> int:
    """
    Args:
        snapshot_path: the path containing the model weights to load
        device: the device on which the model should be loaded
        model: the model for which the weights are loaded
        optimizer: if defined, the optimizer weights to load

    Returns:
        the number of epochs the model was trained for
    """
    snapshot = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(snapshot["model"])
    if optimizer is not None and "optimizer" in snapshot:
        optimizer.load_state_dict(snapshot["optimizer"])
    return snapshot.get("metadata", {}).get("epoch", 0)
Right before the call to model.load_state_dict(), the model's own state-dict keys have no prefix:

model.state_dict().keys()
odict_keys(['backbone.model.conv1.weight', 'backbone.model.bn1.weight', 'backbone.model.bn1.bias', 'backbone.model.layer1.0.conv1.weight', 'backbone.model.layer1.0.bn1.weight', 'backbone.model.layer1.0.bn1.bias', …])
But the state-dict keys of the snapshot all carry a "module." prefix:
snapshot['model'].keys()
odict_keys(['module.backbone.model.conv1.weight', 'module.backbone.model.bn1.weight', 'module.backbone.model.bn1.bias', 'module.backbone.model.layer1.0.conv1.weight', 'module.backbone.model.layer1.0.bn1.weight', 'module.backbone.model.layer1.0.bn1.bias', 'module.backbone.model.layer1.0.conv2.weight', 'module.backbone.model.layer1.0.bn2.weight', 'module.backbone.model.layer1.0.bn2.bias', 'module.backbone.model.layer1.0.conv3.weight', 'module.backbone.model.layer1.0.bn3.weight', 'module.backbone.model.layer1.0.bn3.bias', 'module.backbone.model.layer1.0.downsample.0.weight' …
So loading the state_dict threw an error because the keys didn't match.
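For illustration, the same key mismatch can be reproduced with any small model (a toy example with made-up layer names, not DeepLabCut's actual classes):

```python
import torch

# Simulate a checkpoint saved from a wrapped model: every key gets a
# "module." prefix that the bare model does not expect.
net = torch.nn.Linear(2, 2)
prefixed = {"module." + k: v for k, v in net.state_dict().items()}

try:
    net.load_state_dict(prefixed)  # strict=True is the default
except RuntimeError as e:
    print(type(e).__name__)  # RuntimeError: missing/unexpected keys
```

With the default strict=True, load_state_dict raises a RuntimeError listing the missing and unexpected keys, which is the error reported above.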
My theory is that the model was wrapped in torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel during training. These wrappers parallelize training across multiple GPUs by holding the original network in a module attribute, so PyTorch automatically prepends "module." to every parameter name in the saved state_dict.
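The prefixing is easy to see with a minimal model (this runs on CPU too, since DataParallel can be constructed without GPUs):

```python
import torch

net = torch.nn.Linear(2, 2)
print(list(net.state_dict().keys()))      # ['weight', 'bias']

# Wrapping the model stores it under a `module` attribute, so every
# state-dict key gains a "module." prefix.
wrapped = torch.nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```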
So my fix was to strip the leading "module." prefix from the keys of the saved state dict (using str.removeprefix rather than str.replace, since replace would also rewrite the substring anywhere else in a key):

snapshot = torch.load(snapshot_path, map_location=device)
# model.load_state_dict(snapshot["model"])
new_state_dict = {k.removeprefix("module."): v for k, v in snapshot["model"].items()}
model.load_state_dict(new_state_dict)
The model was loaded from the saved snapshot successfully.
Could you please take a look? Thank you!
Operating System
RedHat 9
DeepLabCut version
DLC 3.0.0rc4...
DeepLabCut mode
single animal
Device type
gpu
Steps To Reproduce
Please see above
Relevant log output
No response
Anything else?
No response
Code of Conduct
- I agree to follow this project's Code of Conduct