# Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)
## Conversation
Dr. CI: As of commit 6e4efb1 with merge base 5196ef1, you can merge normally (2 unrelated failures: one flaky and one unstable trunk job).
**soulitzer** left a comment:
LGTM! Just some more bikeshedding on the error message
On `torch/utils/__init__.py` (outdated):

```python
raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
                   "this can happen when you try to run backward on a tensor that was swapped. "
                   "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
                   "this could happen if trying to run backward changing the device or dtype of the module "
```
| "this could happen if trying to run backward changing the device or dtype of the module " | |
| "this could happen if trying to run backward after changing the device or dtype of the module " |
Actually, maybe I would frame it as:

- changing the device/dtype after running forward (and then running backward) is not allowed
- please change the device/dtype before running forward

because if you tell me to swap `.cpu()` with `backward()`, the changes don't take effect unless the user also wanted to run a second iteration?
Gotcha, does the updated phrasing capture this appropriately?
looks good!
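For context, here is a sketch of how the final message might read after incorporating this feedback (illustrative only; the exact merged wording may differ):

```python
# Illustrative phrasing only, following the review discussion above;
# the actual merged error message may differ.
raise RuntimeError(
    "Trying to execute AccumulateGrad node that was poisoned by swap_tensors, "
    "this can happen when you try to run backward on a tensor that was swapped. "
    "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)`, "
    "changing the device or dtype of the module after running forward "
    "(and then running backward) is not allowed; please change the device/dtype "
    "before running forward."
)
```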
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed; the first few are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-stable), trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
**Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)**

### Before this PR:

`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the
# AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```

### After this PR:

`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by `AccumulateGrad`. A pre-hook will be registered on the `AccumulateGrad` node so that it will fail if it is called (i.e. if the user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node it will error that it was
# poisoned by swap_tensors
```

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters will have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address #126814 (comment). Previously, the following would fail at `m.cpu()`; we want users to be able to write it, and instead raise an error only if the user ever attempts to backward through the poisoned `AccumulateGrad` node:

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: #127313
Approved by: https://github.com/soulitzer
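As a rough illustration of the poisoning mechanism (a minimal sketch, not the PR's actual implementation in `torch/utils/__init__.py`; the `view_as` trick for reaching the node and the `poison_hook` name are illustrative):

```python
import torch

a = torch.randn(2, 3, requires_grad=True)

# Reach the AccumulateGrad node for leaf tensor `a`: a trivial non-leaf
# view of `a` has a grad_fn whose next_functions point back at the leaf's
# AccumulateGrad node.
acc_grad = a.view_as(a).grad_fn.next_functions[0][0]

def poison_hook(grad_outputs):
    # Mimic the error raised when a poisoned node is executed.
    raise RuntimeError("AccumulateGrad node was poisoned by swap_tensors")

# Node.register_prehook runs the hook just before the node executes
# during backward, so any backward reaching this node now errors.
acc_grad.register_prehook(poison_hook)

out = (a * 2).sum()
out.backward()  # raises RuntimeError from poison_hook
```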
Pull Request resolved: #126814
Approved by: https://github.com/JackCaoG, https://github.com/albanD
ghstack dependencies: #127313
…y and .to('meta')) (#126819)
Pull Request resolved: #126819
Approved by: https://github.com/albanD
ghstack dependencies: #127313, #126814
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Stack from ghstack (oldest at bottom):
Differential Revision: D58094197
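For reference, a minimal usage sketch of the failure mode the poisoned node guards against: the case flagged in review, where the module's device is changed between forward and backward. This assumes swap-on-conversion is enabled as in the PR body; the exact error raised would be the message discussed in review above.

```python
import torch
import torch.nn as nn

# Assumed enabled for the swap-on-conversion path described in the PR body.
torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(3, 5)
inp = torch.randn(2, 3)

out = m(inp)          # forward initializes AccumulateGrad nodes for m's parameters
m.cpu()               # now succeeds; the swapped params' AccumulateGrad nodes are poisoned
out.sum().backward()  # expected to raise: backward reaches a poisoned AccumulateGrad node
```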