@MMathisLab MMathisLab commented Dec 23, 2024

This pull request improves the handling of snapshot weights loaded with torch.load(..., weights_only=True). PyTorch snapshots saved with older release candidates could contain numpy floats, which fail to load with weights_only=True. Using those snapshots was cumbersome, because the pytorch_config.yaml had to be edited to add load_weights_only: false keys for both the detector and the pose model. This pull request makes the following improvements:

Fix the issue with numpy>=1.25

For users with numpy>=1.25 installed, the problem is fixed: the float64 class that caused the failure can be added to the safe globals, as done in _add_numpy_to_torch_safe_globals. Current snapshots now load without error, just as they do with weights_only=False.

torch.serialization.add_safe_globals([np.dtype, Float64DType, scalar])

This doesn't work with numpy<1.25, as Float64DType did not yet exist there (it was a dtype that could only be used internally), and there is no easy way to add np.dtype[np.float64] to the safe globals.
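The version gate described above can be sketched as follows. This is a minimal illustration, not the code from this pull request: the names numpy_has_public_float64_dtype and add_numpy_to_safe_globals are hypothetical, and the sketch assumes a PyTorch release that provides torch.serialization.add_safe_globals (2.4 or later) and that the scalar reconstruction helper lives in numpy's multiarray module.

```python
import re


def numpy_has_public_float64_dtype(numpy_version: str) -> bool:
    """Return True for numpy>=1.25, where np.dtypes.Float64DType is public."""
    match = re.match(r"(\d+)\.(\d+)", numpy_version)
    if match is None:
        return False  # unparseable version string: be conservative
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 25)


def add_numpy_to_safe_globals() -> bool:
    """Allowlist the numpy classes when possible; return whether it was done.

    Hypothetical helper. numpy and torch are imported lazily so that the
    version check above can be used without either package installed.
    """
    import numpy as np
    import torch

    if not numpy_has_public_float64_dtype(np.__version__):
        return False  # Float64DType is not importable on older numpy

    try:
        from numpy._core.multiarray import scalar  # numpy 2.x layout
    except ImportError:
        from numpy.core.multiarray import scalar  # numpy 1.x layout

    torch.serialization.add_safe_globals(
        [np.dtype, np.dtypes.Float64DType, scalar]
    )
    return True
```

Returning a boolean instead of raising lets the caller fall back to the other options in this pull request (the global default or the snapshot fix) when the allowlist cannot be extended.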

A global variable is set to handle the default weights_only value

The global variable sets the default value used for load_weights_only when none is specified in the pytorch_config.yaml. This value is read and changed through the get_load_weights_only and set_load_weights_only functions, which can be imported from deeplabcut.pose_estimation_pytorch:

>>> from deeplabcut.pose_estimation_pytorch import get_load_weights_only, set_load_weights_only
>>> print(get_load_weights_only())
True
>>> set_load_weights_only(False)
>>> print(get_load_weights_only())
False

When calling torch.load without load_weights_only being specified, get_load_weights_only() is used to get the default value. So when loading snapshots that are known to be safe, set_load_weights_only(False) can be called at the start of a script so that all snapshots are loaded with weights_only=False.
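The fallback order described above (explicit config key first, global default second) can be sketched like this. It is an illustration only: the function names and the module-level variable are hypothetical stand-ins, not the DeepLabCut implementation, and the runner configuration is assumed to be a plain mapping.

```python
from typing import Any, Mapping

# Hypothetical module-level default, standing in for the global variable.
_default_weights_only: bool = True


def get_default_weights_only() -> bool:
    """Return the current process-wide default."""
    return _default_weights_only


def set_default_weights_only(value: bool) -> None:
    """Change the process-wide default used when the config has no key."""
    global _default_weights_only
    _default_weights_only = bool(value)


def resolve_weights_only(runner_config: Mapping[str, Any]) -> bool:
    """Prefer an explicit load_weights_only key, else use the default."""
    explicit = runner_config.get("load_weights_only")
    if explicit is None:
        return get_default_weights_only()
    return bool(explicit)
```

With this order, a value written in the configuration always wins, so changing the default only affects models whose configuration is silent on the matter.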

Users on numpy<1.25 whose snapshots fail to load can therefore load them without modifying the pytorch_config.yaml: they only need to import set_load_weights_only from deeplabcut.pose_estimation_pytorch and call set_load_weights_only(False) before loading their snapshots.

The initial default load_weights_only value can also be set with a TORCH_LOAD_WEIGHTS_ONLY environment variable, which makes it easier to set this value when working with the GUI. The DeepLabCut GUI can be launched with TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut, which sets the default load_weights_only value to False.
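One way an initial default could be read from the environment is sketched below. The function name default_from_env, the accepted spellings, and the choice to fall back to True on an unrecognised value are all assumptions made for illustration; the pull request does not state how the variable is parsed.

```python
import os
from typing import Mapping, Optional

# Assumed spellings; the real parser may accept a different set.
_TRUE_VALUES = {"1", "true", "yes", "on"}
_FALSE_VALUES = {"0", "false", "no", "off"}


def default_from_env(
    environ: Optional[Mapping[str, str]] = None,
    fallback: bool = True,
) -> bool:
    """Read TORCH_LOAD_WEIGHTS_ONLY, keeping the safe fallback if unclear."""
    env = os.environ if environ is None else environ
    raw = env.get("TORCH_LOAD_WEIGHTS_ONLY")
    if raw is None:
        return fallback
    value = raw.strip().lower()
    if value in _TRUE_VALUES:
        return True
    if value in _FALSE_VALUES:
        return False
    return fallback  # unrecognised text: stay with the safer setting
```

Falling back to True on unrecognised input keeps a typo in the variable from silently disabling the safer loading mode.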

Function to fix snapshots containing numpy float values

A new fix_snapshot_metadata function is added to replace numpy floats with Python floats in existing snapshots.
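The kind of conversion this involves can be sketched as a recursive walk over the snapshot metadata. This is not the fix_snapshot_metadata implementation, whose signature is not shown here; to_builtin_floats is a hypothetical helper, and it relies on the fact that np.float64 is a subclass of Python's float, so it needs no numpy import. Other numpy scalar types, such as np.float32, are not covered by this sketch.

```python
from typing import Any


def to_builtin_floats(obj: Any) -> Any:
    """Return a copy of obj with float subclasses replaced by plain floats.

    Hypothetical helper. np.float64 subclasses float, so it is caught by
    the isinstance check; plain floats and all other values pass through.
    """
    if isinstance(obj, dict):
        return {key: to_builtin_floats(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_builtin_floats(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_builtin_floats(value) for value in obj)
    if isinstance(obj, float) and type(obj) is not float:
        return float(obj)  # e.g. np.float64(0.5) becomes 0.5
    return obj
```

A fixed snapshot would then be produced by loading the original once with weights_only=False (only for a trusted file), converting its metadata, and saving it again.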

Improved error message when a snapshot fails to load

The error message when a snapshot fails to load is improved and more descriptive. It now says:

ERROR:root:
Failed to load the snapshot: snapshot-best-200.pt.

If you trust the snapshot that you're trying to load, you can try
calling `Runner.load_snapshot` with `weights_only=False`. See the 
error message below for more information and warnings.
You can set the `weights_only` parameter in the model configuration (
the content of the pytorch_config.yaml), as:

'''
runner:
  load_weights_only: False
'''

If it's the detector snapshot that's failing to load, place the
`load_weights_only` key under the detector runner:

'''
detector:
    runner:
      load_weights_only: False
'''

You can also set the default `load_weights_only` that will be used when
the `load_weights_only` variable is not set in the `pytorch_config.yaml`
using `deeplabcut.pose_estimation_pytorch.set_load_weights_only(value)`:

'''
from deeplabcut.pose_estimation_pytorch import set_load_weights_only
set_load_weights_only(True)
'''

You can also set the value for `load_weights_only` with a 
`TORCH_LOAD_WEIGHTS_ONLY` environment variable. If you call 
`TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut`, it will launch the
DeepLabCut GUI with the default `load_weights_only` value to False.
If you set this value to `False`, make sure you only load snapshots that
you trust.

@MMathisLab MMathisLab merged commit 1d6ce0d into pytorch_dlc Dec 23, 2024
1 check passed
@MMathisLab MMathisLab deleted the niels/improved_weight_only_handling branch December 23, 2024 20:44