[CPU] Added torch.bmm for complex tensors #42383
Conversation
[ghstack-poisoned]
    @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
    @tf32_on_and_off(0.01)
    def test_mm(self, device, dtype):
        def _test_mm(n, m, p, dtype, genf):
test/test_torch.py (outdated)
        1e-1, 1e-1, 1e-5, _float_types2),
    ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
        1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
        1e-1, 1e-1, 1e-4, _complex_and_float_types2, _cpu_types + _complex_types, True, [tf32_on_and_off(0.005)]),
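For context on what the `addbmm` entry above exercises: `torch.addbmm(M, b1, b2, beta=..., alpha=...)` reduces a batched matrix product into a single matrix, `beta*M + alpha*sum_i b1[i] @ b2[i]`. A plain-Python sketch of that semantics (no PyTorch, tiny 2x2 case, illustrative only):

```python
# Plain-Python sketch of torch.addbmm semantics:
# out = beta * M + alpha * sum_i (b1[i] @ b2[i]), with M of shape (n, p)
# and b1, b2 batches of (n, m) and (m, p) matrices. Illustrative only.
def matmul(a, b):
    n, m, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]

def addbmm(M, b1, b2, beta=1, alpha=1):
    n, p = len(M), len(M[0])
    acc = [[0] * p for _ in range(n)]
    for x, y in zip(b1, b2):           # sum over the batch dimension
        prod = matmul(x, y)
        for i in range(n):
            for j in range(p):
                acc[i][j] += prod[i][j]
    return [[beta * M[i][j] + alpha * acc[i][j] for j in range(p)]
            for i in range(n)]
```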
@mruberry IIUC, _cpu_types is an empty list, and _cpu_types + _complex_types would mean the CPU test runs only with complex dtypes?
Boy, does this set of tests need to be entirely rewritten.
Your understanding is correct. Note that "testing" on CPU here just means verifying that the function and method variants of the op produce the same value. There's no check that those values are correct.
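The pattern being criticized can be sketched as follows (illustrative plain Python, not the actual test harness): the "test" passes whenever the function and method variants agree, even if both are wrong.

```python
# Sketch of the criticized CPU "test": it only checks that the function form
# and the method form agree, so a bug shared by both passes undetected.
def mm_function(a, b):
    # deliberately WRONG matmul (elementwise product) to show the gap
    return [[a[i][j] * b[i][j] for j in range(len(b[0]))] for i in range(len(a))]

class Mat:
    def __init__(self, rows):
        self.rows = rows
    def mm(self, other):
        # method variant delegates to the same (wrong) function variant
        return mm_function(self.rows, other.rows)

a, b = Mat([[1, 2], [3, 4]]), Mat([[5, 6], [7, 8]])
# the "test": function and method agree ...
assert mm_function(a.rows, b.rows) == a.mm(b)
# ... yet neither equals the true matrix product [[19, 22], [43, 50]]
```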
I didn't know that the CPU tests in TestTensorDeviceOps only compare the function and method results, which basically doesn't test anything for a lot of ops. Will remove it.
"testing" on CPU means results in low precision (bfloat16, or whatever _cpu_types is) are compared to results in fp32 precision. For _cpu_types=fp32 it would indeed just compare the function and method.
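That comparison pattern — compute in reduced precision, check against a higher-precision reference within a tolerance — can be sketched with standard-library float32 rounding (illustrative only; the real harness uses tensors and per-dtype tolerances):

```python
# Sketch of comparing a low-precision result against a high-precision reference.
import struct

def to_fp32(x):
    # round a Python float (fp64) to fp32 by packing and unpacking 4 bytes
    return struct.unpack("f", struct.pack("f", x))[0]

ref = sum(0.1 for _ in range(10000))   # accumulate in fp64 (the "reference")
low = 0.0
for _ in range(10000):
    low = to_fp32(low + 0.1)           # accumulate with fp32 rounding each step

# the test passes if the low-precision result is within a loose tolerance
assert abs(low - ref) < 1e-1 * max(1.0, abs(ref))
```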
💊 CI failures summary and remediations: as of commit 42de72f, ci.pytorch.org reports 1 failure (more details on the Dr. CI page). This comment was automatically generated by Dr. CI and has been revised 16 times.
Test Plan - Updated existing tests to run for complex dtypes as well. Also added tests for `torch.addmm`, `torch.badmm` [ghstack-poisoned]
        np_out = torch.full((n, p), float('nan'), device=device)
        self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))

        if dtype.is_complex and device.startswith('cuda'):
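The NaN-prefilled `out=` pattern above catches kernels that fail to write every output element: any untouched slot stays NaN and fails the equality check. A plain-Python sketch of the idea (hypothetical `mm` helper, not the real `torch.mm`):

```python
# Sketch of the out=NaN pattern: prefill the output with NaN so unwritten
# elements are caught, since NaN never compares equal to a real result.
import math

def mm(a, b, out=None):
    n, m, p = len(a), len(b), len(b[0])
    result = [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)]
              for i in range(n)]
    if out is not None:
        for i in range(n):
            for j in range(p):
                out[i][j] = result[i][j]   # the "kernel" writes every element
        return out
    return result

a, b = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
out = [[math.nan] * 2 for _ in range(2)]   # prefill with NaN
assert mm(a, b) == mm(a, b, out=out)       # equal only if every slot was written
```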
Consider self.device_type == 'cuda', or torch.device(device).type, instead of parsing the string with startswith.
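The suggestion is to derive the device type via `torch.device(device).type` rather than prefix matching. A pure-Python stand-in for that attribute (illustrative; the real `torch.device` also validates the string) shows why parsing is more robust:

```python
# Stand-in for torch.device(device).type: split off an optional ":index".
# Illustrative only, not the real torch.device.
def device_type(device):
    return device.split(":")[0]

assert device_type("cuda") == "cuda"
assert device_type("cuda:1") == "cuda"
# startswith can misfire on a hypothetical future device name:
assert "cuda_like".startswith("cuda") and device_type("cuda_like") != "cuda"
```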
zasdfgbnm left a comment:
Overall LGTM
    @onlyCPU
    @dtypes(torch.float)
    @dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]))
NIT: Could we unify torch.float32, torch.float64, torch.cfloat, torch.cdouble vs. torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]?
Do you mean replace all instances of torch.float32, torch.float64, torch.cfloat, torch.cdouble with torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]?
I mean the opposite: torch.float32, torch.float64, torch.cfloat, torch.cdouble looks better than *(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]).
I think it's better to use torch.testing.get_all_complex_dtypes(): if we add more complex dtypes in the future (for example, support for torch.chalf or others), we wouldn't have to update the tests one by one to cover the new dtype. Just including it in torch.testing.get_all_complex_dtypes() would suffice.
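The maintainability argument can be sketched like this (hypothetical stand-ins, not `torch.testing`'s real lists): tests that consume the helper automatically pick up a newly added dtype, while hard-coded lists would each need editing.

```python
# Hypothetical stand-in for torch.testing.get_all_complex_dtypes(): one
# registry so new dtypes propagate to every test that consumes the helper.
_COMPLEX_DTYPES = ["complex64", "complex128"]

def get_all_complex_dtypes():
    return list(_COMPLEX_DTYPES)

test_dtypes = get_all_complex_dtypes() + ["float32", "float64"]
assert len(test_dtypes) == 4

# Adding a new complex dtype (e.g. a half-precision complex) in ONE place ...
_COMPLEX_DTYPES.append("complex32")
# ... and every test rebuilding its list from the helper sees it:
assert "complex32" in get_all_complex_dtypes() + ["float32", "float64"]
```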
I think you need to rebase to fix the "docker image not found" error.
Test Plan - Updated existing tests to run for complex dtypes as well. Also added tests for `torch.addmm`, `torch.badmm` Differential Revision: [D22960339](https://our.internmc.facebook.com/intern/diff/D22960339) [ghstack-poisoned]
@anjali411 merged this pull request in c9346ad.
@jeffdaily currently cgemm/zgemm is disabled on ROCm, but rocBLAS actually supports it - what would it take to enable it?
Summary: Revert "Skips some complex tests on ROCm (#42759)". This reverts commit 55b1706. Use new cuda_to_hip_mappings.py from #43004. Fixes #42383 (comment) CC sunway513 Pull Request resolved: #43744 Reviewed By: glaringlee Differential Revision: D23391263 Pulled By: ngimel fbshipit-source-id: ddf734cea3ba69c24f0d79cf1b87c05cdb45ec3d

Stack from ghstack:

Test Plan - Updated existing tests to run for complex dtypes as well. Also added tests for `torch.addmm`, `torch.badmm`

Differential Revision: D22960339