@mikaylagawarecki (Contributor) commented May 28, 2024

### Before this PR:

`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```

### After this PR:

`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 *if* the second reference is held by `AccumulateGrad`.

A pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node, it will error that it was poisoned by swap_tensors
```
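For intuition, the effect of such a pre-hook can be sketched in Python using `torch.autograd.graph.Node.register_prehook` (a minimal illustration only; the PR registers its hook internally inside `swap_tensors`, and the `view_as` trick below is just one way to reach a leaf tensor's `AccumulateGrad` node):

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
out = (a * 2).sum()

# The view_as trick exposes the AccumulateGrad node of the leaf tensor `a`.
acc_grad = a.view_as(a).grad_fn.next_functions[0][0]

def poison(grad_outputs):
    # Mimics the poisoning pre-hook: executing this node is now an error.
    raise RuntimeError("AccumulateGrad node was poisoned by swap_tensors")

acc_grad.register_prehook(poison)

try:
    out.backward()  # backward reaches the poisoned node and raises
except RuntimeError as e:
    print(e)
```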

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after the forward pass. Specifically, this is intended to address #126814 (comment). Previously, the following would fail at `m.cpu()`; we want users to be able to do this, and instead raise an error only if the user ever attempts to backward through the poisoned `AccumulateGrad` node.

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```
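With this PR, that flow is expected to work once swap-based conversion is enabled; a minimal sketch, assuming `torch.__future__.set_swap_module_params_on_conversion(True)` is set before the conversion (per the error message discussed below):

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()

m.cpu()  # now succeeds: the stale AccumulateGrad reference is poisoned instead of rejected

# A fresh forward builds a fresh graph, so training can continue normally.
out2 = m(inp)
out2.sum().backward()
```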

Stack from ghstack (oldest at bottom):

Differential Revision: D58094197

pytorch-bot bot commented May 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127313

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit 6e4efb1 with merge base 5196ef1:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mikaylagawarecki marked this pull request as draft May 28, 2024 18:51
@albanD removed their request for review May 28, 2024 21:44
@mikaylagawarecki marked this pull request as ready for review May 29, 2024 13:22
@mikaylagawarecki marked this pull request as draft May 29, 2024 15:28
@soulitzer (Contributor) left a comment:

LGTM! Just some more bikeshedding on the error message

```python
raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
                   "this can happen when you try to run backward on a tensor that was swapped. "
                   "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
                   "this could happen if trying to run backward changing the device or dtype of the module "
```

Suggested change:

```diff
-"this could happen if trying to run backward changing the device or dtype of the module "
+"this could happen if trying to run backward after changing the device or dtype of the module "
```


Actually… maybe I would frame it as:

1. changing device/dtype after running forward (and then running backward) is not allowed
2. please change device/dtype before running forward

because if you tell me to swap `.cpu()` with `backward()`, the changes don't take effect unless the user also wanted to run a second iteration?
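i.e. the recommended order would be something like the following hypothetical sketch (reusing the `m`/`inp` names from the PR description):

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 5)
m.cpu()  # change device/dtype first...

inp = torch.randn(2, 3)
out = m(inp)          # ...then run forward
out.sum().backward()  # backward sees a graph built after the conversion
```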

@mikaylagawarecki (Author) replied:

Gotcha, does the updated phrasing capture this appropriately?


looks good!

@mikaylagawarecki (Author) commented:

@pytorchbot merge

pytorch-bot bot added the `ciflow/trunk` label May 29, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-stable), trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14)


@mikaylagawarecki added the topic: improvements label May 30, 2024
@mikaylagawarecki (Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorch-bot bot pushed a commit that referenced this pull request May 30, 2024
…s a reference (#127313)


Pull Request resolved: #127313
Approved by: https://github.com/soulitzer
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


Labels: ciflow/trunk, Merged, release notes: nn, topic: improvements
