Fix wrapper subclass serialization with custom sizes / strides #137030
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137030. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit bdd1d61 with merge base 0ccd39a. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
albanD
left a comment
SGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #129366 Since NJT has custom serialization logic, we need an NJT-specific fix to clear out cached sizes / strides PyCapsules. Eventually, we should switch NJT to use the default serialization logic, but this depends on #125622 being addressed. This PR also makes serialization more complete by explicitly handling `lengths`, `ragged_idx`, and the `metadata_cache`, ensuring correct operation for both contiguous and non-contiguous NJTs (see the sketch below). Pull Request resolved: #137031 Approved by: https://github.com/soulitzer ghstack dependencies: #137030
Called out via torchrec integration: `lengths` was not handled properly. Future work (not related to non-contiguous NJTs): #137275. Pull Request resolved: #137124 Approved by: https://github.com/soulitzer ghstack dependencies: #137030, #137031
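As a hedged illustration of the explicit state handling described in the two commits above, here is a minimal toy sketch. It is not the real `NestedTensor` implementation; the class and attribute names are stand-ins. It round-trips `lengths` and `ragged_idx` through pickling while dropping unpicklable cached entries from a metadata cache:

```python
import pickle

import torch


class ToyJagged:
    """Toy stand-in for an NJT-like container (not the real NestedTensor)."""

    def __init__(self, values, offsets, lengths=None, ragged_idx=1):
        self.values = values
        self.offsets = offsets
        self.lengths = lengths        # present for non-contiguous layouts
        self.ragged_idx = ragged_idx  # which dim is ragged
        self.metadata_cache = {}      # may hold unpicklable cached objects

    def __getstate__(self):
        state = self.__dict__.copy()  # copy so the live object isn't mutated
        # Keep only picklable cache entries; cached PyCapsules must be dropped.
        state["metadata_cache"] = {
            k: v
            for k, v in self.metadata_cache.items()
            if type(v).__name__ != "PyCapsule"  # crude check, for illustration
        }
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)


# Round-trip a non-contiguous-style instance with explicit lengths.
t = ToyJagged(
    values=torch.randn(6, 3),
    offsets=torch.tensor([0, 2, 6]),
    lengths=torch.tensor([2, 4]),
)
t2 = pickle.loads(pickle.dumps(t))
assert torch.equal(t2.lengths, t.lengths) and t2.ragged_idx == t.ragged_idx
```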
Stack from ghstack (oldest at bottom):
Fixes #130154
This PR takes the strategy outlined in the above issue and clears out any cached sizes / strides PyCapsules before serialization. This affects the default subclass serialization logic.
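A minimal sketch of that strategy, assuming hypothetical cache-key names (the actual keys used by the wrapper-subclass size / stride caching may differ): the cached capsules are popped from the instance `__dict__` before the pickle state is captured, since PyCapsule objects cannot be pickled.

```python
import torch

# Hypothetical key names for illustration only; the real cache keys used
# by the wrapper-subclass machinery may differ.
_CAPSULE_KEYS = ("_sym_sizes_capsule", "_sym_strides_capsule")


def clear_cached_capsules(t: torch.Tensor) -> None:
    # Drop any cached size / stride PyCapsules from the instance __dict__
    # so they are never captured in the serialized state.
    for key in _CAPSULE_KEYS:
        t.__dict__.pop(key, None)
```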
The PyCapsule issue also affects `deepcopy`, so that's fixed here as well.

Note: I originally tried utilizing a context manager to remove / restore cached PyCapsules after serialization, but in practice the state returned from `_reduce_ex_internal()` references the actual `tensor.__dict__`, so the problem persists once the cached values are restored. Instead, we have to be careful to remove the cached values in the right place so they're not re-cached when pulling out size / stride information for serialization.
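To see why the remove / restore approach fails, here is a minimal sketch using a plain Python object (behavior observed on CPython; not tensor-specific code): the state captured by the default reduce protocol is the instance `__dict__` itself, not a snapshot, so restoring the cached value afterwards leaks it right back into the state the pickler will write.

```python
class Wrapper:
    pass


w = Wrapper()
w.data = [1, 2, 3]             # some picklable payload
w.cached_capsule = object()    # stand-in for an unpicklable PyCapsule

# Simulate the remove / restore approach:
saved = w.__dict__.pop("cached_capsule")
state = w.__reduce_ex__(2)[2]  # capture the pickle state (protocol 2)
w.cached_capsule = saved       # restore the cache afterwards

print(state is w.__dict__)        # True: the state aliases __dict__
print("cached_capsule" in state)  # True: the restore re-poisoned the state
```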