Conversation

@dvrogozh
Contributor

@dvrogozh dvrogozh commented Nov 6, 2024

Starting with version 2.4, PyTorch introduces a stricter check for the objects that can be loaded with torch.load(). Starting with version 2.6, loading with weights_only=True requires allowlisting such objects.

This commit adds an allowlist of some numpy objects used to load model checkpoints. Its use is restricted to a context manager. Users can still call torch.serialization.add_safe_globals() to add other objects to the safe globals list.

The Accelerate library ran into the same problem and addressed it with PR #3036.

Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036

CC: @muellerzr @SunMarc

@dvrogozh
Contributor Author

@muellerzr, @SunMarc, @ArthurZucker: can you please comment on this PR? See issue #34631 for details.

Member

@SunMarc SunMarc left a comment

Nice! Thanks for adding this! Left a comment.

@ydshieh
Collaborator

ydshieh commented Nov 15, 2024

I am getting

FAILED tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training - AttributeError: module 'numpy' has no attribute 'dtypes'. Did you mean: 'dtype'?

when running

python3 -m pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training

against this PR.

@dvrogozh
Contributor Author

@ydshieh: this might be due to the numpy version. numpy.dtypes was added in 1.25 according to https://numpy.org/doc/2.1/reference/routines.dtypes.html#module-numpy.dtypes. Locally I have 1.26.4. Which version do you have?

I will work on using the context manager, since there is alignment on that, and will also tune the list per numpy version.

@ydshieh
Collaborator

ydshieh commented Nov 15, 2024

On our CI runner, I get numpy==1.24.3.

@mikaylagawarecki

The numpy globals for dtypes that need to be allowlisted might need an if statement depending on whether the numpy version is < 1.25; there is some documentation on this at https://pytorch.org/docs/main/notes/serialization.html#troubleshooting-weights-only
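That version gate can be sketched like this (the specific dtype classes are illustrative; the numpy.dtypes module only exists from numpy 1.25 onwards):

```python
import numpy as np
from packaging import version

# Base allowlist that resolves on any numpy version.
allowlist = [np.ndarray, np.dtype]

# numpy.dtypes (and classes like np.dtypes.UInt32DType) appeared in 1.25,
# so gate on the installed version to avoid an AttributeError on 1.24.x.
if version.parse(np.__version__) >= version.parse("1.25.0"):
    allowlist.extend([np.dtypes.UInt32DType, np.dtypes.Float64DType])
```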

Collaborator

@ArthurZucker ArthurZucker left a comment

cc @muellerzr if you can have a look as well!

Collaborator

We could have a SAFE_TRANSFORMERS_GLOBAL with these, no? That way people can easily update them.
TBH I prefer the context manager, but I want as little duplication as possible!

Contributor Author

I found that calling torch.serialization.add_safe_globals() still works to add additional safe global objects. SAFE_TRANSFORMERS_GLOBAL can also be considered. Let me know if you see the need.

Contributor Author

Should I add any other numpy dtypes to the list? As of now I spotted only np.uint32 in the Transformers list as the one needed.

Contributor

The only one I don't see from accelerate is encode; however, if things pass here without it, it's Accelerate-specific and we don't need to worry about it.

Contributor Author

Transformers tests did pass on my side without adding encode. This indeed seems Accelerate-specific.

Contributor

@muellerzr muellerzr left a comment

Thanks! Just a documentation suggestion, but this all looks correct.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dvrogozh
Contributor Author

Thanks! Just a documentation suggestion, but this all looks correct.

@muellerzr: done, added a link to the Accelerate PR.

Member

@SunMarc SunMarc left a comment

LGTM! Just a nit.

@dvrogozh
Contributor Author

LGTM! Just a nit.

@SunMarc: addressed; reused the approach from Accelerate for the numpy.core deprecation.
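The numpy.core handling referred to here can be sketched as follows (the attribute probe is the pattern; the chosen global is illustrative):

```python
import numpy as np

# numpy.core was renamed to numpy._core in numpy 2.0; the old name only
# survives through a deprecation shim, so probe for the new one first and
# fall back to the old one on numpy 1.x.
np_core = getattr(np, "_core", None) or np.core

# Example of a global resolved through it: the helper numpy uses when
# unpickling ndarrays.
reconstruct = np_core.multiarray._reconstruct
```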

Collaborator

Just a nit: should it be "2.6.0" here, or is "2.4.0" really necessary?

Contributor Author

@dvrogozh dvrogozh Nov 22, 2024

Switched to version < 2.6.0a0. Indeed, when switching to the context manager I overlooked that it was introduced later. To summarize:

  • torch.serialization.add_safe_globals appeared in PyTorch 2.4
  • torch.serialization.safe_globals (context manager) appeared in 2.5
  • PyTorch 2.6 flipped the default of weights_only in torch.load from False to True

Overall, it does not make sense to have this code work for versions earlier than 2.6 unless we start calling torch.load with an explicit weights_only=True.
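The timeline above can be encoded as a small helper (a sketch; the feature names in the dict are made up for illustration):

```python
from packaging import version

def torch_serialization_features(torch_version: str) -> dict:
    """Map a torch version string to the serialization features it has,
    per the timeline above (illustrative, not an exhaustive list)."""
    rel = version.parse(torch_version).release  # e.g. (2, 5, 1)
    return {
        "add_safe_globals": rel >= (2, 4),               # appeared in 2.4
        "safe_globals_context_manager": rel >= (2, 5),   # appeared in 2.5
        "weights_only_defaults_true": rel >= (2, 6),     # flipped in 2.6
    }
```

Comparing release tuples keeps nightly strings like "2.6.0.dev20241120" on the 2.6 side of each gate.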

Collaborator

Hi! A tiny question: how do I get 2.6.0a0 installed? I know how to install the nightly, but it reports dev202411xx instead of a0.

Collaborator

Anyway, good to use a0 here for now. Once 2.6 is released, we can change it to 2.6.

Contributor Author

Hi! A tiny question: how do I get 2.6.0a0 installed?

I get this version when building from sources, and <2.6.0 does not work for my build. So 2.6.0a0 is my best effort to get the check working for my current build. I did not know that nightly builds report dev202411xx; I thought they also gave a0. I wonder whether the check will still work for nightlies?

Contributor Author

I checked: <2.6.0a0 won't work with the nightly. So I switched to a check I once spotted in code by Narsil. This should handle both cases, building from sources and using the 2.6 nightly (I checked; it works for both on my side):

if version.parse(torch.__version__).release < version.parse("2.6").release:
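Why the .release comparison works where a plain <2.6.0a0 bound fails can be seen directly with packaging (the version strings are examples of what a nightly and a source build report):

```python
from packaging import version

nightly = version.parse("2.6.0.dev20241115")  # what nightly wheels report
source = version.parse("2.6.0a0")             # what a source build reports

# Dev releases sort *before* the a0 pre-release of the same version, so a
# plain bound misclassifies nightlies as "older than 2.6":
assert nightly < version.parse("2.6.0a0")

# Comparing only the release tuple ignores dev/pre-release suffixes, so
# both builds are correctly treated as 2.6:
assert not (nightly.release < version.parse("2.6").release)
assert not (source.release < version.parse("2.6").release)

# A genuinely older stable release still falls below the gate:
assert version.parse("2.5.1").release < version.parse("2.6").release
```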

Collaborator

@ydshieh ydshieh left a comment

Thanks

@ArthurZucker ArthurZucker merged commit 1339a14 into huggingface:main Nov 25, 2024
24 checks passed
@ArthurZucker
Collaborator

Thanks for fixing 🤗

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
michael-kuhlmann added a commit to michael-kuhlmann/padertorch that referenced this pull request Jan 31, 2025
torch restricted the unpickler to work with torch.Tensors and a few primitive types (https://pytorch.org/docs/stable/notes/serialization.html#weights-only).
To work with other types, one can either set weights_only=False (which is unsafe)
or add an allowlist of additional types that may be unpickled.
See also the discussion in huggingface/transformers#34632, from which this solution was adopted.
Linked issue: safe_globals are needed to resume training on upcoming PyTorch 2.6 (#34631)