Skip to content

Conversation

@n-poulsen
Copy link
Contributor

Addresses #2725. The torch.load method is now called with weights_only=True as a default. If loading snapshot weights fails, users can set a load_weights_only: False key in their pytorch_config.yaml file to set the weights_only parameter to False. This is explained in the error message when loading weights fails. This should only be done when users trust the content of the snapshot they're loading, as explained in the torch.load docs.

Tests were added to ensure torch.load was called with the correct parameters.

To set weights_only=False, users can edit their pytorch_config.yaml files as:

detector:
    ...
    runner:
        load_weights_only: false  # to load detector weights with `weights_only=False`

runner:
    load_weights_only: false  # to load pose model weights with `weights_only=False`

@n-poulsen n-poulsen added enhancement New feature or request DLC3.0🔥 labels Dec 18, 2024
@n-poulsen n-poulsen merged commit fd6916b into pytorch_dlc Dec 19, 2024
1 check passed
@n-poulsen n-poulsen deleted the niels/snapshot_weight_load branch December 19, 2024 12:54
@MMathisLab
Copy link
Member

MMathisLab commented Dec 19, 2024

this seems to break this fucntion now @n-poulsen

superanimal_analyze_images(superanimal_name,
                           model_name,
                           detector_name= 'fasterrcnn_mobilenet_v3_large_fpn',
                           images = in_image_folder,
                           max_individuals = max_individuals,
                           out_folder = out_image_folder,
                           pose_threshold = 0.5,
                           bbox_threshold = 0.6,
                           plot_skeleton = True,
                           customized_model_config=customized_model_config,
                           customized_pose_checkpoint=customized_pose_checkpoint,
                           customized_detector_checkpoint=customized_detector_checkpoint)

output (on weights that worked yesterday!)

Failed to load the snapshot: /content/drive/MyDrive/.../snapshot-detector-best-070.pt.
If you trust the snapshot that you're trying to load, you can try calling `Runner.load_snapshot` with `weights_only=False`. See the message below for more information and warnings.
You can set the `weights_only` parameter in the model configuration (the content of the pytorch_config.yaml), as:

runner:
  load_weights_only: False


---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
[<ipython-input-15-12fba7293495>](https://localhost:8080/#) in <cell line: 1>()
----> 1 superanimal_analyze_images(superanimal_name,
      2                            model_name,
      3                            detector_name= 'fasterrcnn_mobilenet_v3_large_fpn',
      4                            images = in_image_folder,
      5                            max_individuals = max_individuals,

9 frames
[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/apis/analyze_images.py](https://localhost:8080/#) in superanimal_analyze_images(superanimal_name, model_name, detector_name, images, max_individuals, out_folder, progress_bar, device, pose_threshold, bbox_threshold, plot_skeleton, customized_model_config, customized_pose_checkpoint, customized_detector_checkpoint)
    181         config["detector"]["model"]["box_score_thresh"] = bbox_threshold
    182 
--> 183     predictions = analyze_image_folder(
    184         model_cfg=config,
    185         images=images,

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/apis/analyze_images.py](https://localhost:8080/#) in analyze_image_folder(model_cfg, images, snapshot_path, detector_path, frame_type, device, max_individuals, progress_bar)
    435         device = resolve_device(model_cfg)
    436 
--> 437     pose_runner, detector_runner = get_inference_runners(
    438         model_config=model_cfg,
    439         snapshot_path=snapshot_path,

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/apis/utils.py](https://localhost:8080/#) in get_inference_runners(model_config, snapshot_path, max_individuals, num_bodyparts, num_unique_bodyparts, batch_size, device, with_identity, transform, detector_batch_size, detector_path, detector_transform, dynamic)
    521                 detector_config["pretrained"] = False
    522 
--> 523             detector_runner = build_inference_runner(
    524                 task=Task.DETECT,
    525                 model=DETECTORS.build(detector_config),

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/inference.py](https://localhost:8080/#) in build_inference_runner(task, model, device, snapshot_path, batch_size, preprocessor, postprocessor, dynamic, load_weights_only)
    370                 f"detection. Please turn off dynamic cropping."
    371             )
--> 372         return DetectorInferenceRunner(**kwargs)
    373 
    374     if task != Task.BOTTOM_UP:

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/inference.py](https://localhost:8080/#) in __init__(self, model, **kwargs)
    291             **kwargs: Inference runner kwargs.
    292         """
--> 293         super().__init__(model, **kwargs)
    294 
    295     def predict(self, inputs: torch.Tensor) -> list[dict[str, dict[str, np.ndarray]]]:

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/inference.py](https://localhost:8080/#) in __init__(self, model, batch_size, device, snapshot_path, preprocessor, postprocessor, load_weights_only)
     68 
     69         if self.snapshot_path is not None and self.snapshot_path != "":
---> 70             self.load_snapshot(
     71                 self.snapshot_path,
     72                 self.device,

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/base.py](https://localhost:8080/#) in load_snapshot(snapshot_path, device, model, weights_only)
     84             The content of the snapshot file.
     85         """
---> 86         snapshot = attempt_snapshot_load(snapshot_path, device, weights_only)
     87         model.load_state_dict(snapshot["model"])
     88         return snapshot

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/base.py](https://localhost:8080/#) in attempt_snapshot_load(path, device, weights_only)
    121             "  load_weights_only: False\n```\n"
    122         )
--> 123         raise err
    124 
    125     return snapshot

[/usr/local/lib/python3.10/dist-packages/deeplabcut/pose_estimation_pytorch/runners/base.py](https://localhost:8080/#) in attempt_snapshot_load(path, device, weights_only)
    109     """
    110     try:
--> 111         snapshot = torch.load(path, map_location=device, weights_only=weights_only)
    112     except pickle.UnpicklingError as err:
    113         print(

[/usr/local/lib/python3.10/dist-packages/torch/serialization.py](https://localhost:8080/#) in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1357                         )
   1358                     except pickle.UnpicklingError as e:
-> 1359                         raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
   1360                 return _load(
   1361                     opened_zipfile,

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

and if I edit the pytorch config, adding load_weights_only: False does not solve it.

@MMathisLab
Copy link
Member

@n-poulsen I can confirm if I roll back to this commit, it's working, so I think its this PR that is causing issues cc @maximpavliv !pip install --force-reinstall "git+https://github.com/DeepLabCut/DeepLabCut.git@198fc74#egg=deeplabcut[modelzoo]"

@n-poulsen
Copy link
Contributor Author

@MMathisLab Sorry about that - I should have made the error message more clear. This might happen on weights that were previously trained in DeepLabCut. The issue is with numpy float64 values being stored in the metadata for snapshots (some metric functions returned numpy floats instead of native python floats). The metric functions now all return native python floats, so this should not be an issue in the future.

If we don't want users who already have trained weights to have these issues, we could:

  • Temporarily set the load_weights_only parameter to true as default. This is the current behavior with PyTorch if weights_only is not explicitly set to False
  • Add a helper function to "fix" snapshots, converting all values to python floats instead of numpy floats

Do you see one of those two options as a good way to deal with this issue?

I was able to successfully load such weights by adding the load_weights_only: false parameter. Here it seems that it's the detector snapshot that's failing to load, so the load_weights_only key should be added under the detector's runner in the pytorch_config.yaml file. I'll update the error message to make it clear that the load_weights_only key should be added for the detector or pose snapshot specifically

...
detector:
  runner:
    load_weights_only: false
...
runner:
  load_weights_only: false

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

DLC3.0🔥 enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants