-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Make torch.det() support complex input.
#45980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Once #45898 is merged, I plan to extend the existing tests to complex types. |
Codecov Report
@@ Coverage Diff @@
## master #45980 +/- ##
===========================================
+ Coverage 36.02% 60.81% +24.78%
===========================================
Files 437 2749 +2312
Lines 55230 254094 +198864
===========================================
+ Hits 19898 154530 +134632
- Misses 35332 99564 +64232 |
test/test_torch.py
Outdated
|
|
||
| # no CUDA LU yet, used for torch.det()) | ||
| @onlyCPU | ||
| @skipCPUIfNoLapack |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mruberry should this test go in test_unary_ops or test_linalg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_linalg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alright then @nikitaved we should just update this test:
Line 162 in 8e8fb85
| def test_det(self, device, dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ping @nikitaved please merge this test with the test in test_linalg.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, when I try moving things into test_linalg.py, some checks fail, unlike for test_torch.py.
For example, the test for upper/lower-diagonal matrices fail, as the product of diagonal elements does not match the det result. No problem in test_torch.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All fine, the tests are for quite large tensors, so we run into overflow issues.
anjali411
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nikitaved this looks good overall. could you update the autograd test for det to run for complex? also let's just extend the test for det in test_linalg.py
| auto num_exchanges = (at::arange(1, n + 1, pivs.options().dtype(at::kLong)) != pivs.to(at::kLong)) | ||
| .sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong).fmod_(2); | ||
| // NB: the `.contiguous()` call is added due to the bug in `.prod()` as reported in | ||
| // issue #https://github.com/pytorch/pytorch/issues/34061 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#34061 looks to be fixed, should we remove this workaround?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
Thank you very much for your input! I updated the PR. The tests from @anjali411 , could you please tell whether there is still a limitation of autograd not being able to process functions with several outputs of complex type? |
💊 CI failures summary and remediationsAs of commit ae6ebd4 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 21 times. |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cuda", [&]() { | ||
| prod_functor<scalar_t>{}(iter); | ||
| }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add tests for complex prod. Does it make sense to do it in a separate PR? The backward has to get modified accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: backward for prod will not work for complex inputs, because it relies on nonzero, which is not yet implemented for the complex numbers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes this sounds good let's update the backward formula in the next PR. This PR #46596 enables backward for nonzero for complex. In general, to enable backward for an op for complex, you would need to add an entry in GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable.py.
No I don't think so. Please let me know if you are running into any issues though |
anjali411
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Hi @nikitaved! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
|
@nikitaved can you sign the CLA, rebase the PR and remove the |
b5d9529 to
625191f
Compare
625191f to
ae6ebd4
Compare
|
@anjali411 , sorry, force-pushed from the wrong machine, now it should be fixed. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@anjali411 merged this pull request in 8a3728c. |
|
@nikitaved would you like to create a follow-up PR to add complex autograd support for |
|
@anjali411 , turns out |
|
Hey @nikitaved! @anjali411 is actually on pto at the moment, but that plan makes a lot of sense. I added complex autograd support for torch.linalg.det as a task in the linear algebra tacking issue (#42666) with a link to your comment. Let's track the work there. |
|
@mruberry , thank you, will update the issue there. Meanwhile, can I request your reviews on these PRs? |
|
It turns out there is a circular relationship between dependencies, so, it is going to be a single PR. |
Of course. |
As per title. A minor fix required to make it available for the CPU (
fmoddoes not support complex).For CUDA requires #45898 .