-
Notifications
You must be signed in to change notification settings - Fork 26.3k
torch.sgn for complex tensors #39955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
|
cc. @dylanbespalko for the Vec256 changes |
💊 CI failures summary and remediationsAs of commit 0fed8c8 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 95 times. |
| ['sigmoid', torch.sigmoid], | ||
| ['sigmoid_', torch.sigmoid_], | ||
| ['sign', torch.sign], | ||
| ['sgn', torch.sign], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.sgn?
torch/_torch_docs.py
Outdated
| Example:: | ||
| >>> x=torch.randn(4, dtype=torch.cfloat) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The example doesn't really demonstrate the function because of the random values. Maybe pick some angles, like 0, 45, 90 degrees?
test/test_torch.py
Outdated
| cos_angle = angle.cos() | ||
| sin_angle = angle.sin() | ||
| expected = cos_angle + 1j * sin_angle | ||
| self.assertEqual(x.sgn(), expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this instead assert that x.sgn has the same angle as x.angle and that the absolute value is everywhere one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah sure
|
Added some notes. This also still needs to update torch/_overrides.py, docs/source/tensors.rst, docs/source/torch.rst, docs/source/name_inference.rst, aten/src/ATen/core/aten_interned_strings.h and tools/derivatives.yaml. Can this also be tested by TestTensorDeviceOps and TestTorchMathOps? |
[ghstack-poisoned]
ezyang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't compute normalized vector using trig functions
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` TODO: 1. add tests for backward (waiting on gradcheck PR for complex) [ghstack-poisoned]
mruberry
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor doc nits but otherwise this looks good to me!
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` TODO: 1. add tests for backward (waiting on gradcheck PR for complex) [ghstack-poisoned]
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this still needs the gradients for the real case to be tested, even if you postpone the check for the complex version.
I was thinking of waiting to merge this PR until #43208 is merged and add tests for sgn. |
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` This PR doesn't test the correctness of the gradients. It will be done as a part of auditing all the ops in future once we decide the autograd behavior (JAX vs TF) and add gradchek. Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` This PR doesn't test the correctness of the gradients. It will be done as a part of auditing all the ops in future once we decide the autograd behavior (JAX vs TF) and add gradchek. Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
Codecov Report
@@ Coverage Diff @@
## gh/anjali411/34/base #39955 +/- ##
========================================================
- Coverage 67.86% 67.85% -0.01%
========================================================
Files 384 384
Lines 50026 50029 +3
========================================================
Hits 33948 33948
- Misses 16078 16081 +3
Continue to review full report at Codecov.
|
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x==0` also updates the backward definition of `torch.div`, `torch.abs` Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526) [ghstack-poisoned]
|
@anjali411 merged this pull request in 58b6ab6. |
ghstack-source-id: e4ab67a Pull Request resolved: pytorch/pytorch#39955
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. [ghstack-poisoned]
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. [ghstack-poisoned]
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. [ghstack-poisoned]
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. Differential Revision: [D23998156](https://our.internmc.facebook.com/intern/diff/D23998156) [ghstack-poisoned]
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. Differential Revision: [D23998156](https://our.internmc.facebook.com/intern/diff/D23998156) [ghstack-poisoned]
Summary: Pull Request resolved: #45461 This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable: ``` >>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble) >>> x.pow(3) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: pow does not support automatic differentiation for outputs with complex dtype. ``` The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955 ), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`. Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D23998156 Pulled By: anjali411 fbshipit-source-id: 370eb07fe56ac84dd8e2233ef7bf3a3eb8aeb179
Stack from ghstack:
resolves #36323 by adding
torch.sgnfor complex tensors.torch.sgnreturnsx/abs(x)forx != 0and returns0 + 0jforx==0also updates the backward definition of
torch.div,torch.absDifferential Revision: D23460526