
Conversation

@ysiraichi
Collaborator

@ysiraichi ysiraichi commented Nov 11, 2024

Stack from ghstack (oldest at bottom):

Tracking issue: #138399

This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with
its Python ref implementation. The following example shows the inconsistency between the
two:

```python
def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```
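To make the divergence concrete, it can be modeled in plain Python (a hypothetical sketch, not the actual PyTorch source): the CPU path required the `out` dtype to match the computed dtype exactly, while the old meta kernel performed no such check.

```python
# Hypothetical sketch (not the PyTorch source): modeling the divergent
# dtype checks that produced the behavior above.

def cpu_check(computed_dtype, out_dtype):
    # CPU kernel: requires an exact dtype match with the `out` tensor
    if computed_dtype != out_dtype:
        raise RuntimeError(
            f"Found dtype {out_dtype} but expected {computed_dtype}")

def old_meta_check(computed_dtype, out_dtype):
    # old meta kernel: accepted any `out` dtype without complaint
    pass

# float32 tensor ** int scalar promotes to float32, but `out` is float64:
try:
    cpu_check("Float", "Double")
    cpu_raised = False
except RuntimeError:
    cpu_raised = True

old_meta_check("Float", "Double")  # silently accepted
print(cpu_raised)  # -> True
```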

**~Update:~**

~Note that this happens only for `pow.Tensor_Scalar` overloads. Therefore, this PR needed
two further modifications:~

  • ~Split the `pow` ref implementation, making `pow.Tensor_Scalar` error on mismatching
    output dtypes~
  • ~Create a dispatch for `pow` when `_refs.pow()` is called~

Update:

Changing the `TensorIteratorConfig` for `pow.Tensor_Scalar` was easier and,
after the discussion below, more correct. The solution was to change the
`TensorIteratorBase::build_output_borrowing_argument_owning_unary_op` function,
setting:

  • `cast_common_dtype_to_outputs`; and
  • `enforce_safe_casting_to_output`.
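For intuition, the effect of those two flags can be modeled in plain Python (a rough sketch under my own naming, not the C++ `TensorIterator` code): the op runs in the common dtype, and the result is then cast into the output's dtype only when the cast does not drop to a lower "type kind".

```python
# Rough Python model (hypothetical, not the C++ implementation) of the two
# TensorIteratorConfig flags:
#   cast_common_dtype_to_outputs  -> compute in the common dtype, then cast
#   enforce_safe_casting_to_output -> reject casts to a lower "type kind"

TYPE_KIND = {"bool": 0, "int64": 1, "float32": 2, "float64": 2, "complex64": 3}

def can_safe_cast(src_dtype, dst_dtype):
    # a cast is "safe" when the destination kind is not lower than the source
    return TYPE_KIND[dst_dtype] >= TYPE_KIND[src_dtype]

def pow_tensor_scalar(value, exponent, common_dtype, out_dtype):
    if not can_safe_cast(common_dtype, out_dtype):
        raise RuntimeError(
            f"result type {common_dtype} can't be cast to {out_dtype}")
    result = value ** exponent        # computed in common_dtype
    return result, out_dtype          # then safe-copied into `out`

# float32 -> float64 is a same-kind widening, so it is now accepted:
res, dtype = pow_tensor_scalar(0.5, 2, "float32", "float64")

# float32 -> int64 would lower the type kind, so it is rejected:
try:
    pow_tensor_scalar(0.5, 2, "float32", "int64")
    rejected = False
except RuntimeError:
    rejected = True
```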

@ysiraichi ysiraichi requested a review from mruberry as a code owner November 11, 2024 18:22
@pytorch-bot

pytorch-bot bot commented Nov 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140287

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 0a69f95 with merge base f4ce9ac:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ysiraichi added a commit that referenced this pull request Nov 12, 2024
Tracking issue: #138399

This PR changes the `pow` ref implementation, making its meta kernel consistent with its
CPU implementation. The following example shows the inconsistency between the two:

```python
def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```

ghstack-source-id: 5371f04
Pull Request resolved: #140287
@ysiraichi
Collaborator Author

While the old version of this PR did make the meta implementation of `pow.Tensor_Scalar` overloads consistent with its C++ version, I didn't think that was the right way to go.

According to the developer FAQ:

> For operations that do not participate in type promotion the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source.

This means that, since `pow.Tensor_Scalar` does participate in dtype promotion (it calls `at::result_type(...)`), we can safely copy to an output tensor of a different dtype (though not of a lower kind). It follows that the C++ meta function should be changed (the current state of this PR).

Note: while this implementation works, we are running the whole computation in the dtype of the output argument, instead of just safe-copying the results. Technically, I think this does not reflect the developer FAQ specification. For that, we would need to change `TensorIterator` a bit.
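The numerical difference between the two strategies can be shown without torch (plain-Python sketch; `to_f32` is a helper introduced here to round a double to float32 precision): squaring a float32 value in float64 keeps low-order bits that a float32 computation rounds away.

```python
# Plain-Python illustration (no torch): computing in the wider out dtype vs
# computing in the input dtype and then safe-copying are not bit-identical.
import struct

def to_f32(x):
    # helper (introduced here): round a Python float (float64) to float32
    return struct.unpack("f", struct.pack("f", x))[0]

a = to_f32(1.0 / 3.0)        # a float32 input value
in_out_dtype = a * a         # whole computation done in float64 (out dtype)
safe_copied = to_f32(a * a)  # computed in float32, then copied into float64

assert in_out_dtype != safe_copied  # the strategies differ in low-order bits
```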

@ezyang @amjames Any thoughts on this?

@amjames
Collaborator

amjames commented Nov 12, 2024

I agree that the behavior of running the entire op in the wider dtype does not seem to match the spec outlined in the developer FAQ; however, I suppose it does preserve the user-visible behavior. On the other hand, modifying `TensorIterator` to handle this situation more correctly would be a fairly invasive change and may have many side effects.

I suppose we could try making that modification to TensorIterator here and see what happens, but what you have here looks okay to me.

ysiraichi added a commit that referenced this pull request Nov 12, 2024
ghstack-source-id: 67a3ec5
Pull Request resolved: #140287
@ysiraichi
Collaborator Author

I think I figured out how to change `TensorIterator` without being invasive! As far as I understand, the `build_output_borrowing_argument_owning_unary_op` function is used only for `pow`, which is very convenient. So, I just had to tweak it so that it would do a safe copy to the output tensor.

Contributor

@ezyang ezyang left a comment


beh looks like this is a semantics change though :P

ysiraichi added a commit that referenced this pull request Nov 18, 2024
ghstack-source-id: b902a55
Pull Request resolved: #140287
@ysiraichi
Collaborator Author

While it does change the semantics, in the sense that we no longer expect the output tensor to be of an exact dtype, I think it brings us closer to the `out=` specification.

@ysiraichi
Collaborator Author

The CI failure is unrelated to this PR.

@ysiraichi
Collaborator Author

@pytorchbot merge -i

@pytorch-bot

pytorch-bot bot commented Nov 19, 2024

This PR needs to be approved by an authorized maintainer before merge.

@ysiraichi
Collaborator Author

@ezyang Could you take a look at this PR? I have commented on the semantic change. I don't think this would break anything, though, since the change only makes `pow` less strict regarding the output tensor's dtype.

@ezyang
Contributor

ezyang commented Nov 20, 2024

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 20, 2024
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/ysiraichi/71/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/140287)

pytorchmergebot pushed a commit that referenced this pull request Nov 20, 2024
ghstack-source-id: ce865dd
Pull Request resolved: #140287
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@ysiraichi
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#140287
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the gh/ysiraichi/71/head branch December 21, 2024 02:05