Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Aug 28, 2024

Stack from ghstack (oldest at bottom):

aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
  • cse would de-duplicate the empty nodes
  • reinplacing would add an additional clone (because it can't write to
    both tensors at the same time)
  • the clone lowers into a new buffer + a copy_ kernel
  • the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional
    buffer, it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).

Test Plan:

  • new test (the old number was 27, the new number is 21, so this PR
    helped).

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- cse would de-duplicate the empty nodes
- reinplacing would add an additional clone (because it can't write to
  both tensors at the same time)
- the "clone" is unnecessary - all reinplacing needed was an additional
  buffer, it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).

Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
  helped).

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134703

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 84eb374 with merge base 55236d0 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added a commit that referenced this pull request Aug 28, 2024
aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- cse would de-duplicate the empty nodes
- reinplacing would add an additional clone (because it can't write to
  both tensors at the same time)
- the "clone" is unnecessary - all reinplacing needed was an additional
  buffer, it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).

Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
  helped).

ghstack-source-id: 37a5d74
Pull Request resolved: #134703
@zou3519 zou3519 added ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor labels Aug 28, 2024
@zou3519
Copy link
Contributor Author

zou3519 commented Aug 29, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@zou3519
Copy link
Contributor Author

zou3519 commented Aug 29, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- cse would de-duplicate the empty nodes
- reinplacing would add an additional clone (because it can't write to
  both tensors at the same time)
- the clone lowers into a new buffer + a copy_ kernel
- the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional
  buffer, it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).

Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
  helped).

Pull Request resolved: pytorch#134703
Approved by: https://github.com/yf225
ghstack dependencies: pytorch#134466, pytorch#134490, pytorch#134491
@github-actions github-actions bot deleted the gh/zou3519/1058/head branch October 2, 2024 02:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants