Complex support on GPU for dynamic casting #29612

Conversation
c10/util/TypeCast.h
```cpp
template<typename dest_t, typename complex_src_t>
struct MaybeReal {
  static C10_HOST_DEVICE inline typename complex_src_t::value_type maybe_real(complex_src_t src) {
    return src.real();
  }
};

template<typename complex_src_t>
struct MaybeReal<std::complex<float>, complex_src_t> {
  static C10_HOST_DEVICE inline complex_src_t maybe_real(complex_src_t src) {
    return src;
  }
};

template<typename complex_src_t>
struct MaybeReal<std::complex<double>, complex_src_t> {
  static C10_HOST_DEVICE inline complex_src_t maybe_real(complex_src_t src) {
    return src;
  }
};
```
Thanks for adding MaybeReal. I'm surprised that other scalar_t types like Bool, Half, and BFloat16 can call the scalar_t::real() method, but I guess it worked for std::real(scalar_t)?
@dylanbespalko No, the way it is used ensures that complex_src_t can only be std::complex<something_t>.
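For illustration, here is a minimal sketch (hypothetical helper names, not code from this PR) of how a call site can guarantee that `MaybeReal` is only ever instantiated with a complex source, so `complex_src_t::value_type` is always well-formed and real scalar_t types never reach `src.real()`:

```cpp
#include <complex>
#include <type_traits>

// Assumes the MaybeReal template quoted in the diff above.
template <typename T> struct is_complex_t : std::false_type {};
template <typename T> struct is_complex_t<std::complex<T>> : std::true_type {};

// Selected only when src_t is std::complex<something_t>; only then is
// MaybeReal instantiated.
template <typename dest_t, typename src_t>
inline dest_t cast_src(src_t src, std::true_type /*complex src*/) {
  return static_cast<dest_t>(MaybeReal<dest_t, src_t>::maybe_real(src));
}

// Selected for all real source types (Bool, Half, BFloat16, ...), so
// MaybeReal's body is never instantiated for them.
template <typename dest_t, typename src_t>
inline dest_t cast_src(src_t src, std::false_type /*real src*/) {
  return static_cast<dest_t>(src);
}

template <typename dest_t, typename src_t>
inline dest_t cast_src(src_t src) {
  return cast_src<dest_t>(src, is_complex_t<src_t>{});
}
```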
Also, you could really help me out by adding this change to c10/util/TypeCast.h (lines 48–52). I think it should be:
```cpp
template <typename dest_t, typename src_t>
C10_HOST_DEVICE inline dest_t static_cast_with_inter_type(src_t src) {
  return static_cast<dest_t>(
      static_cast<inter_copy_type_t<dest_t>>(MaybeReal<dest_t, src_t>::maybe_real(src)));
}
```
Oh, I see. Clever...
@dylanbespalko Thanks for the review. That makes sense; I will add it.
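To make the combined behavior concrete, here is a hedged, host-only usage sketch (simplified: `C10_HOST_DEVICE` defined away and `inter_copy_type_t` assumed to be the identity for these types; this is not the actual c10 code):

```cpp
#include <cassert>
#include <complex>

// Host-only sketch; the real code also targets device.
#define C10_HOST_DEVICE

// Assumption for this sketch: identity intermediate type.
template <typename dest_t>
using inter_copy_type_t = dest_t;

// MaybeReal as in the diff above (float specialization omitted for brevity).
template <typename dest_t, typename complex_src_t>
struct MaybeReal {
  static C10_HOST_DEVICE inline typename complex_src_t::value_type maybe_real(complex_src_t src) {
    return src.real();  // complex -> real: drop the imaginary part
  }
};

template <typename complex_src_t>
struct MaybeReal<std::complex<double>, complex_src_t> {
  static C10_HOST_DEVICE inline complex_src_t maybe_real(complex_src_t src) {
    return src;  // complex -> complex: pass through unchanged
  }
};

template <typename dest_t, typename src_t>
C10_HOST_DEVICE inline dest_t static_cast_with_inter_type(src_t src) {
  return static_cast<dest_t>(
      static_cast<inter_copy_type_t<dest_t>>(MaybeReal<dest_t, src_t>::maybe_real(src)));
}

int main() {
  std::complex<float> z{1.5f, -2.0f};
  // complex -> real keeps only the real part.
  assert((static_cast_with_inter_type<float, std::complex<float>>(z)) == 1.5f);
  // complex -> complex widens without losing the imaginary part.
  assert((static_cast_with_inter_type<std::complex<double>, std::complex<float>>(z))
         == std::complex<double>(1.5, -2.0));
  return 0;
}
```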
I would prefer that you accept these changes before mine; I will then pull the update from master.

@dylanbespalko I have moved the logic into …

Looks good to me.
facebook-github-bot left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: In-tree changes to PyTorch to support complex numbers are being submitted here. Out-of-tree support for complex numbers is here: [pytorch-cpu-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex)

- [x] Replaced std::real(a) with a.real() in kernel-level code.
- [x] Fixed the Vec256_base implementation of complex ops so that it works correctly on non-AVX devices.
- [ ] Clean up CopyKernel after #29612 is approved. zasdfgbnm is fixing this issue in #29612; this should be added first.

cc: iotamudelta, ezyang, bddppq

Pull Request resolved: #29607
Differential Revision: D18451046
Pulled By: ezyang
fbshipit-source-id: b9dcd8e25e91cab13bd131b070d027b090cdedc9
Summary: After #29612 is merged, `static_cast_with_inter_type` can automatically convert complex types to their real values, so there is no need to do it inside the copy kernel. This should wait until #29612 is merged; otherwise it won't pass CI.

Pull Request resolved: #29631
Differential Revision: D18485676
Pulled By: ezyang
fbshipit-source-id: 0bbfd551e3d3010f87eef0fce23a1f8a094b7d31
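As an illustration of why the copy kernel no longer needs its own complex handling, here is a hedged sketch (a hypothetical loop, not the actual CopyKernel) of an element-wise copy routed through the new cast:

```cpp
#include <cstdint>

// Hypothetical copy loop (not the actual CopyKernel), assuming the
// static_cast_with_inter_type shown earlier: complex -> real conversion
// happens inside the cast itself, so the loop needs no special case for
// complex sources.
template <typename dest_t, typename src_t>
void copy_loop_sketch(dest_t* dst, const src_t* src, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
    dst[i] = static_cast_with_inter_type<dest_t, src_t>(src[i]);
  }
}
```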
Currently, the dynamic casting mechanism is implemented assuming no support for complex on GPU. This will no longer be true in the near future.
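For context, a minimal sketch of what "dynamic casting" means here (hypothetical and heavily simplified relative to c10/util/TypeCast.h): the source dtype is known only at runtime, so a runtime tag selects the statically typed cast, and that switch must now include complex cases:

```cpp
#include <complex>

// Hypothetical runtime tag, standing in for c10::ScalarType.
enum class DType { Float, Double, ComplexFloat };

// Simplified fetch-and-cast helper: read an element whose type is known only
// at runtime and cast it to the statically known dest_t. Complex sources go
// through the MaybeReal-aware cast shown earlier, so complex -> real drops
// the imaginary part instead of failing to compile.
template <typename dest_t>
dest_t fetch_and_cast_sketch(DType src_type, const void* ptr) {
  switch (src_type) {
    case DType::Float:
      return static_cast<dest_t>(*static_cast<const float*>(ptr));
    case DType::Double:
      return static_cast<dest_t>(*static_cast<const double*>(ptr));
    case DType::ComplexFloat:
      return static_cast_with_inter_type<dest_t, std::complex<float>>(
          *static_cast<const std::complex<float>*>(ptr));
  }
  return dest_t{};  // unreachable for valid tags
}
```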
#29547 could clear some clang warnings, but complex support on GPU is still not complete.
This PR is what should be done for type promotion in order to add support for complex dtypes on GPU, as suggested in #755 (comment).
Note that what is newly added in this PR is not tested, due to the lack of basic support for complex dtypes (I cannot construct a complex tensor). But this PR shouldn't break any existing part of PyTorch.
To merge this PR, consider two options: