Conversation

@nikitaved
Collaborator

As per title. A minor fix is required to make it available on the CPU (fmod does not support complex); CUDA support requires #45898.

@nikitaved nikitaved added the module: complex Related to complex number support in PyTorch label Oct 7, 2020
@nikitaved nikitaved requested a review from anjali411 October 7, 2020 17:46
@nikitaved
Collaborator Author

Once #45898 is merged, I plan to extend the existing tests to complex types.
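For context, one cheap property such a complex-dtype test can check: for a triangular matrix, the determinant is the product of the diagonal entries. A minimal sketch against numpy as a stand-in reference (illustrative only, not the actual PyTorch test code):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))

# For a triangular matrix the determinant is the product of the diagonal,
# which gives a cheap reference value for a complex-dtype det test.
t = np.triu(a)
assert np.allclose(np.linalg.det(t), np.prod(np.diag(t)))
```

The same check is what the discussion below runs into for large matrices, where the naive diagonal product can overflow.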

@ngimel ngimel added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Oct 7, 2020
@codecov

codecov bot commented Oct 7, 2020

Codecov Report

Merging #45980 into master will increase coverage by 24.78%.
The diff coverage is 100.00%.

@@             Coverage Diff             @@
##           master   #45980       +/-   ##
===========================================
+ Coverage   36.02%   60.81%   +24.78%     
===========================================
  Files         437     2749     +2312     
  Lines       55230   254094   +198864     
===========================================
+ Hits        19898   154530   +134632     
- Misses      35332    99564    +64232     


# no CUDA LU yet, used for torch.det()
@onlyCPU
@skipCPUIfNoLapack
Contributor

@mruberry should this test go in test_unary_ops or test_linalg?

Collaborator

test_linalg

Contributor

Alright then, @nikitaved, we should just update this test:

def test_det(self, device, dtype):

Contributor

Ping @nikitaved: please merge this test with the test in test_linalg.py.

Collaborator Author

@nikitaved nikitaved Oct 21, 2020

Hmm, when I try moving things into test_linalg.py, some checks fail, unlike in test_torch.py.
For example, the tests for upper/lower-triangular matrices fail: the product of the diagonal elements does not match the det result. There is no such problem in test_torch.py.

Collaborator Author

All is fine: the tests use quite large tensors, so we run into overflow issues.
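The overflow is easy to reproduce: in single precision, a product of a few hundred factors of magnitude above one leaves the float32 range, so a naive diagonal-product reference breaks down for large matrices. A small numpy sketch (sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d = rng.uniform(2.0, 3.0, size=400).astype(np.float32)

naive = np.prod(d)                               # float32 product overflows to inf
log_det = np.sum(np.log(d.astype(np.float64)))   # log-space sum stays finite
assert np.isinf(naive)
assert np.isfinite(log_det)
```

This is why log-space determinants (slogdet-style) are preferred for large matrices.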

Contributor

@anjali411 anjali411 left a comment

@nikitaved this looks good overall. Could you update the autograd test for det to run for complex? Also, let's just extend the test for det in test_linalg.py.

auto num_exchanges = (at::arange(1, n + 1, pivs.options().dtype(at::kLong)) != pivs.to(at::kLong))
.sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong).fmod_(2);
// NB: the `.contiguous()` call is added due to the bug in `.prod()` reported in
// issue https://github.com/pytorch/pytorch/issues/34061
Collaborator

#34061 looks to be fixed; should we remove this workaround?

Contributor

To @ngimel's point, let's remove this workaround since #34061 is fixed.
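For reference, the pivot-parity trick in the C++ snippet above can be mirrored in a few lines of numpy: count the positions where the LAPACK-style 1-based pivot differs from its index, and take that count mod 2 to get the permutation sign. A self-contained sketch under that assumption (not the actual PyTorch implementation):

```python
import numpy as np

def lu_det(a):
    """Determinant via Gaussian elimination with partial pivoting.

    Tracks pivots in LAPACK's 1-based ipiv convention and recovers the
    permutation sign the same way as the snippet above: count entries
    where pivs[k] != k + 1 and take that count mod 2.
    """
    a = np.array(a, dtype=np.complex128)
    n = a.shape[0]
    pivs = np.zeros(n, dtype=np.int64)
    for k in range(n):
        p = k + int(np.argmax(np.abs(a[k:, k])))
        pivs[k] = p + 1                      # 1-based, LAPACK style
        if p != k:
            a[[k, p]] = a[[p, k]]            # row exchange
        if a[k, k] != 0:
            a[k + 1:, k] /= a[k, k]
            a[k + 1:, k + 1:] -= np.outer(a[k + 1:, k], a[k, k + 1:])
    # parity of the permutation, mirroring the fmod_(2) in the C++ above
    num_exchanges = int(np.sum(np.arange(1, n + 1) != pivs) % 2)
    sign = -1.0 if num_exchanges else 1.0
    return sign * np.prod(np.diag(a))

m = np.array([[1 + 2j, 3], [4, 5 - 1j]])
assert np.allclose(lu_det(m), np.linalg.det(m))
```

Since the exchange count is an integer quantity, computing the parity in an integer dtype is what lets the fix avoid fmod on complex types.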

@nikitaved
Collaborator Author

nikitaved commented Oct 19, 2020

Thank you very much for your input! I updated the PR. The tests from common_methods_invocations.py are disabled because backward for SVD is not yet implemented for complex.

@anjali411, could you please tell me whether autograd is still unable to handle functions with several complex-typed outputs?

@dr-ci

dr-ci bot commented Oct 19, 2020

💊 CI failures summary and remediations

As of commit ae6ebd4 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



Comment on lines +91 to 93
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cuda", [&]() {
prod_functor<scalar_t>{}(iter);
});
Collaborator Author

@nikitaved nikitaved Oct 19, 2020

Add tests for complex prod. Does it make sense to do it in a separate PR? The backward has to be modified accordingly.

Collaborator Author

NOTE: backward for prod will not work for complex inputs, because it relies on nonzero, which is not yet implemented for complex numbers.

Contributor

Yes, this sounds good; let's update the backward formula in the next PR. PR #46596 enables backward for nonzero for complex. In general, to enable backward for an op for complex, you need to add an entry to GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable.py.
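For reference, in the zero-free case the backward for prod just multiplies the incoming gradient by prod(x) / x_i; nonzero enters only to locate zero entries, which need special handling. A real-valued numpy sketch (the complex case additionally needs Wirtinger-calculus conjugation; names here are illustrative):

```python
import numpy as np

def prod_backward(x, grad_out):
    """Gradient of y = prod(x) w.r.t. x when x has no zeros:
    dy/dx_i = y / x_i. Inputs containing zeros need nonzero() to locate
    them, which is why the complex backward waits on complex nonzero."""
    return grad_out * np.prod(x) / x

x = np.array([2.0, 3.0, 4.0])
g = prod_backward(x, 1.0)
assert np.allclose(g, [12.0, 8.0, 6.0])
```

With exactly one zero in x, only that index receives a nonzero gradient (the product of all other entries), which is where a nonzero-based special case comes in.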

@anjali411
Contributor

could you please tell whether there is still a limitation of autograd not being able to process functions with several outputs of complex type?

No, I don't think so. Please let me know if you run into any issues, though.

Contributor

@anjali411 anjali411 left a comment

lgtm

Contributor

@facebook-github-bot facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

Hi @nikitaved!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@anjali411
Contributor

@nikitaved can you sign the CLA, rebase the PR and remove the .contiguous() workaround please?

@nikitaved
Collaborator Author

@anjali411 , sorry, force-pushed from the wrong machine, now it should be fixed.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@anjali411 merged this pull request in 8a3728c.

@anjali411
Contributor

@nikitaved would you like to create a follow-up PR to add complex autograd support for torch.det and torch.prod?

@nikitaved
Collaborator Author

nikitaved commented Nov 17, 2020

@anjali411, it turns out torch.det is a rabbit hole: its backward testing depends on prod, cumsum, cumprod, index_select, index_add, and counting (gradgradcheck, I presume). Since all these methods have quite extensive test coverage, I will submit one PR for each.

@mruberry
Collaborator

Hey @nikitaved! @anjali411 is actually on PTO at the moment, but that plan makes a lot of sense.

I added complex autograd support for torch.linalg.det as a task in the linear algebra tracking issue (#42666) with a link to your comment. Let's track the work there.

@nikitaved
Collaborator Author

@mruberry , thank you, will update the issue there. Meanwhile, can I request your reviews on these PRs?

@nikitaved
Collaborator Author

nikitaved commented Nov 18, 2020

It turns out there is a circular relationship between the dependencies, so it is going to be a single PR.

@mruberry
Collaborator

@mruberry , thank you, will update the issue there. Meanwhile, can I request your reviews on these PRs?

Of course.


Labels

cla signed, Merged, module: complex (Related to complex number support in PyTorch), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
