-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add complex support for torch.mean [CUDA] #47048
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add complex support for torch.mean [CUDA] #47048
Conversation
💊 CI failures summary and remediationsAs of commit de87d22 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
d618111 to
b0de1be
Compare
Codecov Report
@@ Coverage Diff @@
## master #47048 +/- ##
===========================================
+ Coverage 35.95% 53.27% +17.31%
===========================================
Files 438 2747 +2309
Lines 55454 254304 +198850
===========================================
+ Hits 19939 135476 +115537
- Misses 35515 118828 +83313 |
|
@anjali411 Thanks so much for reviewing this PR. |
| }); | ||
| } | ||
|
|
||
| template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setting acc_t = typename c10::scalar_value_type<scalar_t>::type should resolve the issue here
c10::scalar_value_type<scalar_t>::type returns scalar_t for all non-complex dtypes and returns T for c10::complex<T>.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the tip. The latest code now manipulates c10::scalar_value_type<acc_t>::type to get the type of the factor, the overload functions for complex numbers are not needed.
Hi @RockingJavaBean I think we shouldn't need to define overload functions for complex types, after the change I suggested in my comment. But this PR looks great overall, and should be ready to merge after that change! |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| lambda n, d: n.mean(d), | ||
| use_integral=False) | ||
| use_integral=False, | ||
| use_complex=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test doesn't run on CUDA. can you please extend the test for mean in tensor_op_tests to also test complex dtypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out, the tests for complex dtypes are added to tensor_op_tests.
anjali411
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's update the CUDA test for mean to test complex dtypes as well
|
@anjali411 I'm really grateful for your tip on |
anjali411
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm and the windows test failure is an upstream test failure
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@anjali411 thanks so much for reviewing this PR, the CUDA tests for |
|
@RockingJavaBean can you please rebase? |
|
@anjali411 thank you so much for the kind reminder, I just rebased this PR with the latest code. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@anjali411 merged this pull request in f90da88. |
Fixes #46982