Improved loading of snapshot weights with torch.load(..., weights_only=True)
#2823
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This pull request offers Improved handling of loading snapshot weights with
torch.load(..., weights_only=True). PyTorch snapshots saved with older release candidates could contain somenumpyfloats, which failed to load withweights_only=True, which can make it annoying to use them as thepytorch_config.yamlneeded to be modified withload_weights_only: truekeys for both detectors and pose models. In this pull request: the following improvements are made:Fix the issue with
numpy>=1.25For users with
numpy>=1.25installed, the issue is fixed as thefloat64class causing issues can be added to thesafe_globals, as done in_add_numpy_to_torch_safe_globals. Current snapshots will be loaded without error, as they are withweights_only=False.This doesn't work in
numpy<1.25, as theFloat64Dtypedid not yet exist (it was a dtype that could only be used internally), and there is no easy way to addnp.dtype[np.float64]to the safe globals.A global variable is set to handle the default
weights_onlyvalueThe global variable sets the default value given to
load_weights_only, when none is specified in thepytorch_config.yaml. This value is controlled through theget_load_weights_onlyandset_load_weights_onlymethods, which can be imported throughdeeplabcut.pose_estimation_pytorch:When calling
torch.loadwithoutload_weights_onlybeing specified,get_load_weights_only()is used to get the default value. So when loading snapshots that are known to be safe,set_load_weights_only(False)can be called at the start of a script so that all snapshots are loaded withweights_only=False.So for users using
numpy<1.25with snapshots that have issues, they can load the snapshots without having to modify thepytorch_config.yamlby just callingfrom deeplabcut.pose_estimation_pytorch import set_load_weights_onlyandset_load_weights_only(False)before loading their snapshots.The initial default
load_weights_onlyvalue can also be set with anTORCH_LOAD_WEIGHTS_ONLYenvironment variable, which makes it easier to set this value when working with the GUI. The DeepLabCut GUI can just be launched withTORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut, which will set the defaultload_weights_onlyvalue toFalse.Function to fix snapshots containing numpy float values
A new
fix_snapshot_metadatamethod is added to replace numpy floats with python floats in existing snapshots.Improved error message when a snapshot fails to load
The error message when a snapshot fails to load is improved and more descriptive. It now says: