-
Notifications
You must be signed in to change notification settings - Fork 31.5k
Add safe_globals to resume training on PyTorch 2.6 #34632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
2cee855 to
fa62472
Compare
|
@muellerzr, @SunMarc, @ArthurZucker : can you, please, help comment on this PR? see issue #34631 on details. |
SunMarc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice ! Thanks for adding this ! Left a comment
|
I am getting
when running
against this PR. |
|
@ydshieh : this might be due to numpy version. dtypes was added in 1.25 according to https://numpy.org/doc/2.1/reference/routines.dtypes.html#module-numpy.dtypes. Locally I have 1.26.4. Which version do you have? I will work on using context manager since there is an alignment on that and also tune a list per versioning of numpy. |
|
On our CI runner , I get |
|
The numpy GLOBALs for dtypes that need to be allowlisted might need an if statement depending on whether version < 1.25 or not, there's some documentation on this here https://pytorch.org/docs/main/notes/serialization.html#troubleshooting-weights-only |
ArthurZucker
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @muellerzr if you can have a look as well!
src/transformers/trainer.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could have a SAFE_TRANSFORMERS_GLOBAL with these no? this way people can easily update them?
TBH I prefer the context manager but want to have the least duplication as possible!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found that calling torch.serialization.add_safe_globals() still works to add additional safe global staff. SAFE_TRANSFORMERS_GLOBAL can also be considered. Let me know if you see the need.
fa62472 to
276a3a0
Compare
src/transformers/trainer.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I add any other numpy dtypes in the list? As of now I spotted only np.unit32 in the Transformers list as the one needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only one I don't see from accelerate is encode, however if things pass here without it it's accelerate specific and we don't need to worry about it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Transformer tests did pass on my side without adding encode. This indeed seems accelerate specific.
276a3a0 to
4273a30
Compare
muellerzr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just a documentation suggestion but this all looks correct
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
4273a30 to
468aa06
Compare
@muellerz : done, added a link to Accelerate PR. |
SunMarc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ! Just a nit
468aa06 to
27c307f
Compare
@SunMarc : addressed, reused approach from accelerate on |
27c307f to
dbb3112
Compare
src/transformers/trainer.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a nit: should it be "2.6.0" here or it's really necessary being "2.4.0"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switched to version < 2.6.0a0. Indeed, on switching to context manager I overlooked that it was introduced later. Overall:
torch.serialization.add_safe_globalsappeared in pytorch 2.4torch.serialization.safe_globals(context manager) appeared in 2.5- And pytorch 2.6 flipped default of
weights_onlyintorch.loadfromFalsetoTrue
Overall, it indeed does not make sense to have this code working for versions earlier than 2.6 unless we will start calling torch.load with explicit weights_only=True.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi! A tiny question: how to get 2.6.0a0 installed. I know how to install night but it gets dev202411xx instead of a0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway, good to use a0 here for now. Once 2.6 is released, we can change it to 2.6.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi! A tiny question: how to get 2.6.0a0 installed.
I am getting this building from sources. And <2.6.0 does not work for me on my build. So, 2.6.0a0 is my best effort to get the check working for my current build. I did not know that nightly builds get dev202411xx, I thought they also give a0. I wonder will the check still work for nightly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked. <2.6.0a0 won't work with nightly. So, I switched to a check I ones spotted in a code by Narsil. This should handle both cases, building from sources and using 2.6 nightly (I checked - works for both on my side):
if version.parse(torch.__version__).release < version.parse("2.6").release:
dbb3112 to
0505f2c
Compare
ydshieh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True requires allowlisting of such objects. This commit adds allowlist of some numpy objects used to load model checkpoints. Usage is restricted by context manager. User can still additionally call torch.serialization.add_safe_globals() to add other objects into the safe globals list. Accelerate library also stepped into same problem and addressed it with PR-3036. Fixes: huggingface#34631 See: pytorch/pytorch#137602 See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals See: huggingface/accelerate#3036 Signed-off-by: Dmitry Rogozhkin <[email protected]>
0505f2c to
820ca4a
Compare
|
Thanks for fixing 🤗 |
Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True requires allowlisting of such objects. This commit adds allowlist of some numpy objects used to load model checkpoints. Usage is restricted by context manager. User can still additionally call torch.serialization.add_safe_globals() to add other objects into the safe globals list. Accelerate library also stepped into same problem and addressed it with PR-3036. Fixes: huggingface#34631 See: pytorch/pytorch#137602 See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals See: huggingface/accelerate#3036 Signed-off-by: Dmitry Rogozhkin <[email protected]>
torch restricted the unpickler to work with torch.Tensors and a few primitive types (https://pytorch.org/docs/stable/notes/serialization.html#weights-only). To work with other types, one can either set weights_only=False (which is unsafe) or add a whitelist of additional types that may be unpickled. See also discussion in huggingface/transformers#34632 where the solution was adopted from.
Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with
torch.load(). Starting from version 2.6 loading withweights_only=Truerequires allowlisting of such objects.This commit adds allowlist of some
numpyobjects used to load model checkpoints. Usage is restricted by context manager. User can still calltorch.serialization.add_safe_globals()to add other objects into the safe globals list.Accelerate library also stepped into same problem and addressed it with PR-3036.
Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
CC: @muellerzr @SunMarc