Port bmm and baddbmm from TH to ATen #42553
Conversation
💊 CI failures summary (Dr. CI): as of commit 909b366, 1 job failed on ci.pytorch.org.
zasdfgbnm left a comment:
Not finished yet. Will post more comments later.
Per @gchanan's request, ports from TH to ATen should also beef up test coverage (in particular, various discontiguity patterns on input/output, and proper runtime errors for arguments on different devices).
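A minimal sketch of the kind of coverage meant here (illustrative only, not the PR's actual tests):

```python
import torch

# Discontiguous input: a strided slice is not contiguous, but the result
# must match the contiguous computation.
a = torch.randn(4, 3, 10)[:, :, ::2]   # shape (4, 3, 5), non-contiguous
b = torch.randn(4, 5, 2)
assert not a.is_contiguous()
torch.testing.assert_close(torch.bmm(a, b), torch.bmm(a.contiguous(), b))

# Arguments on different devices must raise a clear RuntimeError.
if torch.cuda.is_available():
    try:
        torch.bmm(a.cuda(), b)          # one CUDA argument, one CPU argument
    except RuntimeError:
        pass
    else:
        raise AssertionError("expected a device-mismatch RuntimeError")
```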
@anjali411 Could you please rebase? Looks like there are lots of flaky tests.
test/test_torch.py (Outdated)
```python
@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
```
Please don't do so. We test all dtypes on purpose: if a dtype is supported, it should run correctly; if it is not supported, it should raise an error.
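A hedged sketch of that policy (not the PR's test code): every dtype either runs correctly or raises.

```python
import torch

# Every dtype must either compute a correct result or raise; it must
# never silently misbehave.
for dtype in (torch.float32, torch.float64, torch.complex64,
              torch.complex128, torch.float16, torch.bfloat16):
    a = torch.randn(2, 3, 4).to(dtype)
    b = torch.randn(2, 4, 5).to(dtype)
    try:
        out = torch.bmm(a, b)
        assert out.shape == (2, 3, 5) and out.dtype == dtype
    except RuntimeError:
        pass  # an unsupported dtype should fail loudly, not return garbage
```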
cc: @ngimel
Synced offline: the CPU bmm and baddbmm have multiple code paths; some of them support bfloat16 and float16, some don't. So depending on the input, half and bfloat16 may or may not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498
So @zasdfgbnm, @ngimel, and I agreed to add full support for torch.float16 and torch.bfloat16 in a follow-up PR and leave this one as is.
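A hedged probe sketch of the point above (`bmm_supported` is a hypothetical helper, not PyTorch API): support can only be answered per concrete call.

```python
import torch

# Because CPU bmm dispatches to different code paths, half/bfloat16
# support may depend on the concrete inputs; this probe only reports
# whether one particular call succeeds.
def bmm_supported(dtype, batch=2, m=3, k=4, n=5):
    a = torch.ones(batch, m, k, dtype=dtype)
    b = torch.ones(batch, k, n, dtype=dtype)
    try:
        torch.bmm(a, b)
        return True
    except RuntimeError:
        return False

print(bmm_supported(torch.float16), bmm_supported(torch.bfloat16))
```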
zasdfgbnm left a comment:
LGTM! Thanks for working on this!
Codecov Report

```diff
@@                 Coverage Diff                 @@
##           gh/anjali411/46/base   #42553   +/- ##
===================================================
  Coverage               81.22%    81.22%
===================================================
  Files                    1837      1837
  Lines                  198087    198087
===================================================
+ Hits                   160893    160897       +4
+ Misses                  37194     37190       -4
```
@anjali411 merged this pull request in e1ee3bf.
Summary: Now that #42553 is merged, we can delete a bit of code from the tests and enable some of the skipped complex tests. Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests xpass. We need to be careful next time with `unittest.expectedFailure`.
Pull Request resolved: #47910
Reviewed By: zhangguanheng66
Differential Revision: D25052130
Pulled By: mruberry
fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
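For context on the `unittest.expectedFailure` pitfall mentioned in the summary, a minimal sketch:

```python
import unittest

class Demo(unittest.TestCase):
    @unittest.expectedFailure
    def test_already_fixed(self):
        self.assertEqual(1 + 1, 2)  # passes -> an "unexpected success"

# Depending on how the runner reports unexpected successes, a fixed bug
# can leave a stale expectedFailure marker behind unnoticed.
if __name__ == "__main__":
    unittest.main()
```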
```cpp
/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
```
Is this macro `CUDA_VERSION >= 11200` intended? If you mean CUDA 11.2, it should be 11020. I'm not sure CUDA 11.2 was a thing back in November 2020. 😅
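For reference, `CUDA_VERSION` encodes `major.minor` as `major * 1000 + minor * 10`; a quick sanity check (a sketch, not PyTorch code):

```python
# CUDA_VERSION packs major.minor as major * 1000 + minor * 10,
# so CUDA 11.2 is 11020; 11200 would mean CUDA 11.20.
def cuda_version_macro(major: int, minor: int) -> int:
    return major * 1000 + minor * 10

assert cuda_version_macro(11, 2) == 11020
assert cuda_version_macro(10, 1) == 10010
```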
No harm done, workaround is good.
@xwang233 No, my bad! We should fix that to avoid confusion in the future.
Stack from ghstack:
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions.

Closes #24539
Differential Revision: D24893511
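A minimal usage sketch of the ported ops with a complex dtype (illustrative, not taken from the PR):

```python
import torch

b1 = torch.randn(10, 3, 4, dtype=torch.complex64)
b2 = torch.randn(10, 4, 5, dtype=torch.complex64)

out = torch.bmm(b1, b2)  # batched matmul -> shape (10, 3, 5)

# baddbmm(input, batch1, batch2): beta * input + alpha * bmm(batch1, batch2)
base = torch.randn(10, 3, 5, dtype=torch.complex64)
res = torch.baddbmm(base, b1, b2, beta=0.5, alpha=2.0)
torch.testing.assert_close(res, 0.5 * base + 2.0 * torch.bmm(b1, b2))
```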