pow: fix meta function output argument dtype check. #140287
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/140287. Note: links to docs will display an error until the docs builds have completed.

❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.

✅ No Failures as of commit 0a69f95 with merge base f4ce9ac.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Tracking issue: #138399

This PR changes the `pow` ref implementation, making its meta kernel consistent with its CPU implementation. The following example shows the inconsistency between the two:

```python
import torch

def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```

ghstack-source-id: 5371f04
Pull Request resolved: #140287
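The cross-device divergence shown above (one device raising while the other silently succeeds) can be checked mechanically. Here is a small hypothetical helper, not part of the PR; the name `raises_consistently` and its signature are illustrative:

```python
# Hypothetical helper (not from the PR) that checks whether a callable
# behaves consistently across a set of devices: fn(device) must either
# raise RuntimeError on every device or on none of them. The meta/CPU
# bug above is exactly a case where the two disagree.
def raises_consistently(fn, devices=("cpu", "meta")):
    outcomes = []
    for device in devices:
        try:
            fn(device)
            outcomes.append(False)   # no error on this device
        except RuntimeError:
            outcomes.append(True)    # raised on this device
    # consistent iff every device agrees on raising vs. not raising
    return all(o == outcomes[0] for o in outcomes)
```

Before this fix, passing the `run` function above would return `False` (CPU raises, meta does not); after it, both devices agree.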
While the old version of this PR did make the meta implementation of `pow` consistent with the CPU one, the developer FAQ specifies that an `out=` argument's dtype only needs to be one the computed result can be safely cast to. Which means that a wider `out` dtype (e.g. float64 for a float32 computation) should be accepted.

Note: while this implementation works, we are running the whole computation in the output argument's dtype, instead of just safe-copying the results. Technically, I think this does not reflect the developer FAQ specification. For that, we would need to change the `TensorIteratorConfig`.
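The difference between running the computation in the output dtype and computing in the common dtype and then safe-copying can be sketched in plain Python, using `struct` to emulate float32 rounding (illustrative only; this is not PyTorch code):

```python
import struct

def to_f32(x):
    # Round a Python float (which is float64) to float32 precision.
    return struct.unpack('f', struct.pack('f', x))[0]

a = to_f32(1.1)  # a float32 input value, held as a Python float

# Strategy 1: run the whole computation in the output dtype (float64).
computed_in_out_dtype = a * a

# Strategy 2: compute in the common dtype (float32), then safe-copy the
# rounded result into the float64 output. For a product of two float32
# values, rounding the exact float64 product to float32 reproduces the
# correctly-rounded float32 multiplication.
safe_copied = to_f32(a * a)

# The two strategies can disagree in the low-order bits of the result.
```

The user-visible dtype of `out` is float64 either way; only the low-order bits of the stored values differ, which is why the distinction matters for the spec but not for the example above.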
I agree that the behavior of running the entire op in the wider dtype does not seem to match the spec outlined in the developer FAQ; however, I suppose it does preserve the user-visible behavior. On the other hand, I suppose we could try making that modification to the `TensorIteratorConfig`.
Tracking issue: #138399

This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with its Python ref implementation. The following example shows the inconsistency between the two:

```python
import torch

def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```

ghstack-source-id: 67a3ec5
Pull Request resolved: #140287
I think I figured out how to change the `TensorIteratorConfig`.
ezyang
left a comment
beh looks like this is a semantics change though :P
While it does change the semantics, in the sense that we no longer require the output tensor to be of an exact dtype, I think it brings us closer to the developer FAQ specification.
The CI failure is unrelated to this PR.

@pytorchbot merge -i

This PR needs to be approved by an authorized maintainer before merge.

@pytorchbot merge -r

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Tracking issue: pytorch#138399

This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with its Python ref implementation. The following example shows the inconsistency between the two:

```python
import torch

def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```

**~Update:~** ~Note that this happens only for `pow.Tensor_Scalar` overloads. Therefore, this PR needed 2 further modifications:~

- ~Split the `pow` ref implementation, making `pow.Tensor_Scalar` error on mismatching output dtypes~
- ~Create a dispatch for `pow` when `_refs.pow()` is called~

**Update:** Changing the `TensorIteratorConfig` for `pow.Tensor_Scalar` was easier and, after the discussion below, more correct. The solution was to change the `TensorIteratorBase::build_output_borrowing_argument_owning_unary_op` function, setting:

- `cast_common_dtype_to_outputs`; and
- `enforce_safe_casting_to_output`.

Pull Request resolved: pytorch#140287
Approved by: https://github.com/ezyang
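As a rough mental model of the check that `enforce_safe_casting_to_output` implies, here is a simplified Python sketch. This is not PyTorch's actual implementation, and the promotion table below covers only the handful of dtypes needed for the example; the real rules live in ATen's type-promotion machinery:

```python
# Simplified sketch of a safe-casting check: an out= tensor is accepted
# only if the computed common dtype can be safely (losslessly) cast to
# the output dtype. This is NOT PyTorch's real promotion table -- just
# enough entries for the float32/float64 case from the example above.
SAFE_CASTS = {
    "float32": {"float32", "float64", "complex64", "complex128"},
    "float64": {"float64", "complex128"},
}

def check_output_dtype(common_dtype, out_dtype):
    if out_dtype not in SAFE_CASTS.get(common_dtype, ()):
        raise RuntimeError(
            f"result type {common_dtype} can't be cast to the "
            f"desired output type {out_dtype}"
        )

# Widening float32 -> float64 is safe, so a float64 `out` is accepted:
check_output_dtype("float32", "float64")
```

Under this model, the example's float64 `out` for a float32 computation is accepted on every device, while a narrowing request (a float32 `out` for a float64 result) would raise on every device, which is the consistency the PR is after.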