Could not load a saved snapshot due to mismatched state_dict keys #2749

@thuann2cats

Is there an existing issue for this?

  • I have searched the existing issues

Bug description

Hi DLC team, I’m running this evaluation call:
deeplabcut.evaluate_network(config_path, plotting=True)

It fails while loading the state_dict. The error is raised at line 67 of the file below, inside load_snapshot:
<environment folder>/lib/python3.10/site-packages/deeplabcut/pose_estimation_pytorch/runners/base.py

    @staticmethod
    def load_snapshot(
        snapshot_path: str | Path,
        device: str,
        model: ModelType,
        optimizer: torch.optim.Optimizer | None = None,
    ) -> int:
        """
        Args:
            snapshot_path: the path containing the model weights to load
            device: the device on which the model should be loaded
            model: the model for which the weights are loaded
            optimizer: if defined, the optimizer weights to load

        Returns:
            the number of epochs the model was trained for
        """
        snapshot = torch.load(snapshot_path, map_location=device)
        model.load_state_dict(snapshot['model'])
        if optimizer is not None and 'optimizer' in snapshot:
            optimizer.load_state_dict(snapshot["optimizer"])

        return snapshot.get("metadata", {}).get("epoch", 0)

Right before the call to model.load_state_dict(), the model has these state dict keys:

model.state_dict()
OrderedDict([('backbone.model.conv1.weight', torch.Size([64, 3, 7...4e-02]]]])), ('backbone.model.bn1.weight', torch.Size([64])_ten..., 1., 1.])), ('backbone.model.bn1.bias', torch.Size([64])_ten..., 0., 0.])), ('backbone.model.layer...nv1.weight', torch.Size([64, 64, ....1165]]]])), ('backbone.model.layer...bn1.weight', torch.Size([64])_ten..., 1., 1.])), ('backbone.model.layer...0.bn1.bias', torch.Size([64])_ten..., 0., 0.])), ('backbone.model.layer...nv2.weight', torch.Size([64, 64, ....0103]]]])), ('backbone.model.layer...bn2.weight', torch.Size([64])_ten..., 1., 1.])), ('backbone.model.layer...0.bn2.bias', torch.Size([64])_ten..., 0., 0.])), ('backbone.model.layer...nv3.weight', torch.Size([256, 64,....0079]]]])), ('backbone.model.layer...bn3.weight', torch.Size([256])_te..., 0., 0.])), ('backbone.model.layer...0.bn3.bias', torch.Size([256])_te..., 0., 0.])), ('backbone.model.layer...e.0.weight', torch.Size([256, 64,....0011]]]])), ('backbone.model.layer...e.1.weight', torch.Size([256])_te..., 1., 1.])), ...])

But the state dict keys of the snapshot look like this:

snapshot['model'].keys()
odict_keys(['module.backbone.model.conv1.weight', 'module.backbone.model.bn1.weight', 'module.backbone.model.bn1.bias', 'module.backbone.model.layer1.0.conv1.weight', 'module.backbone.model.layer1.0.bn1.weight', 'module.backbone.model.layer1.0.bn1.bias', 'module.backbone.model.layer1.0.conv2.weight', 'module.backbone.model.layer1.0.bn2.weight', 'module.backbone.model.layer1.0.bn2.bias', 'module.backbone.model.layer1.0.conv3.weight', 'module.backbone.model.layer1.0.bn3.weight', 'module.backbone.model.layer1.0.bn3.bias', 'module.backbone.model.layer1.0.downsample.0.weight' …

So load_state_dict threw an error because the keys didn’t match: every key in the snapshot carries a "module." prefix that the model’s keys lack.
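For anyone who wants to confirm this kind of mismatch before patching anything, here is a minimal sketch that compares two sets of keys and checks for a shared prefix. The helper names compare_keys and has_shared_prefix are made up for illustration and are not part of DeepLabCut or PyTorch; the sample keys are taken from the output above.

```python
from typing import Iterable


def compare_keys(model_keys: Iterable[str], snapshot_keys: Iterable[str]) -> dict:
    """Report keys the model expects but the snapshot lacks, and vice versa."""
    model_set, snapshot_set = set(model_keys), set(snapshot_keys)
    return {
        "missing": sorted(model_set - snapshot_set),      # expected by the model, absent in snapshot
        "unexpected": sorted(snapshot_set - model_set),   # present in snapshot, unknown to the model
    }


def has_shared_prefix(keys: Iterable[str], prefix: str = "module.") -> bool:
    """True if there is at least one key and every key starts with `prefix`."""
    keys = list(keys)
    return bool(keys) and all(key.startswith(prefix) for key in keys)


# Sample keys copied from the report; in practice these would come from
# model.state_dict().keys() and snapshot["model"].keys().
model_keys = ["backbone.model.conv1.weight", "backbone.model.bn1.weight"]
snapshot_keys = ["module." + key for key in model_keys]

report = compare_keys(model_keys, snapshot_keys)
print(report["missing"])
print(report["unexpected"])
print(has_shared_prefix(snapshot_keys))
```

If every model key shows up as missing, every snapshot key as unexpected, and the snapshot keys all share one prefix, the weights themselves are probably fine and only the naming differs.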

My theory is that the model was wrapped in torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel during training. These wrappers parallelize training across multiple GPUs and store the original model in an attribute named module, so the state dict of the wrapped model has "module." prepended to every key.

So my fix was to remove the “module.” prefix from the keys of the saved state dict:

snapshot = torch.load(snapshot_path, map_location=device)
# model.load_state_dict(snapshot['model'])
new_state_dict = {k.replace('module.', ''): v for k, v in snapshot['model'].items()}
model.load_state_dict(new_state_dict)

With this change, the model loaded from the saved snapshot successfully.
Could you please take a look? Thank you!
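One caveat with the workaround above: str.replace removes "module." anywhere in a key, so a parameter with a submodule that is itself named "module" would be renamed too. Below is a slightly more defensive sketch that only strips a leading prefix. The helper name strip_module_prefix is hypothetical (not a DeepLabCut or PyTorch function), and the example keys are invented to show the difference.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_module_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> OrderedDict:
    """Return a copy of `state_dict` with a leading `prefix` removed from each key.

    Keys that do not start with `prefix` are kept unchanged, and the original
    key order is preserved.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Invented example: the second key contains "module." in the middle as well.
saved = OrderedDict(
    [
        ("module.backbone.model.conv1.weight", 1),
        ("module.head.module.bias", 2),
    ]
)
cleaned = strip_module_prefix(saved)
print(list(cleaned))
```

With the snippet from this report, the call would become model.load_state_dict(strip_module_prefix(snapshot['model'])). Recent PyTorch releases also appear to ship torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which does a similar job in place; whether it is available depends on the installed version.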

Operating System

RedHat 9

DeepLabCut version

DLC 3.0.0rc4...

DeepLabCut mode

single animal

Device type

gpu

Steps To Reproduce

Please see above

Relevant log output

No response

Anything else?

No response
