
Create a c10::complex #35284

@zasdfgbnm


For the list of PRs, see the linked comment on #35284.

Unfortunately, std::complex cannot be used in CUDA device code, so on CUDA all complex operators are implemented by converting std::complex to thrust::complex and operating on that instead. But this approach has its own issues:

First, thrust::complex itself has problems. For example, on CUDA 9 and ROCm it does not have a default constructor, which is very annoying when working with complex numbers, and the resulting build errors can take a long time to debug. ATen/native/cuda/CUDA9Workarounds.cuh, for instance, was created to work around this problem.
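To illustrate why a missing default constructor hurts, here is a minimal sketch using two hypothetical stand-in types (not thrust::complex itself). Generic code that value-initializes a scalar_t, such as the starting accumulator of a reduction, compiles only when the type is default-constructible; make_accumulator is an illustrative helper, not a PyTorch function.

```cpp
#include <type_traits>

// Hypothetical stand-in mimicking a complex type with no default constructor,
// like the CUDA 9 / ROCm thrust::complex behavior described above.
template <typename T>
struct complex_without_default {
  constexpr complex_without_default(T r, T i) : re(r), im(i) {}
  T re, im;
};

// Hypothetical stand-in for a complex type that does provide one.
template <typename T>
struct complex_with_default {
  constexpr complex_with_default() : re(), im() {}
  constexpr complex_with_default(T r, T i) : re(r), im(i) {}
  T re, im;
};

// Generic code routinely needs a value-initialized scalar_t, for example the
// starting accumulator of a reduction. Instantiating this with
// complex_without_default<float> is a compile error.
template <typename scalar_t>
constexpr scalar_t make_accumulator() {
  return scalar_t();
}

static_assert(!std::is_default_constructible<complex_without_default<float>>::value,
              "no default constructor");
static_assert(std::is_default_constructible<complex_with_default<float>>::value,
              "has a default constructor");
```

In real kernels the failure shows up far from the type definition (inside a templated loop or reduction), which is part of why these build errors are slow to track down.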

Second, this makes the dtype <--> scalar_t mapping diverge between CPU and CUDA. This is really bad, because helpers in c10/core/ScalarType.h and c10/util/TypeCast.h rely on the mapping being the same on both, and I believe most of our engineers have this assumption in mind too. Forgetting that it does not hold can lead to subtle bugs. For example, I just noticed that needs_dynamic_casting in ATen/native/cuda/Loops.cuh always returns true when the dtype is complex, because thrust::complex and std::complex are not the same type.
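The effect of the divergence can be sketched with a simplified, hypothetical analogue of a dynamic-casting check. ScalarType, to_scalar_type, device_complex_float, and needs_cast are illustrative names, not the actual definitions in ScalarType.h or Loops.cuh; device_complex_float stands in for thrust::complex<float>.

```cpp
#include <complex>

// Hypothetical, simplified stand-ins; not the actual c10/ATen definitions.
enum class ScalarType { Float, ComplexFloat, Undefined };

// Map a C++ scalar type to its dtype. Only std::complex<float> is registered
// for ComplexFloat, mirroring the CPU-side mapping.
template <typename T>
struct to_scalar_type {
  static constexpr ScalarType value = ScalarType::Undefined;
};
template <>
struct to_scalar_type<float> {
  static constexpr ScalarType value = ScalarType::Float;
};
template <>
struct to_scalar_type<std::complex<float>> {
  static constexpr ScalarType value = ScalarType::ComplexFloat;
};

// Stand-in for thrust::complex<float>: same layout, but a distinct C++ type.
struct device_complex_float {
  float re, im;
};

// Simplified analogue of a dynamic-casting check: a cast is needed whenever
// the kernel's argument type does not map back to the tensor's dtype.
template <typename arg_t>
constexpr bool needs_cast(ScalarType tensor_dtype) {
  return to_scalar_type<arg_t>::value != tensor_dtype;
}

// The CPU type round-trips; the device stand-in never does.
static_assert(!needs_cast<std::complex<float>>(ScalarType::ComplexFloat), "");
static_assert(needs_cast<device_complex_float>(ScalarType::ComplexFloat), "");
```

Because only the CPU type is registered for ComplexFloat, a distinct device-side type falls through to Undefined and the check reports that a cast is needed every time. Using one complex type on both backends would let the mapping round-trip everywhere.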

Such behavior can lead to subtle bugs. For example, the copy kernel cannot convert complex values (#35225), and PR #34749 also fails on complex, but only on Linux (Windows is fine). Both fail with "unspecified launch failure". Although I have not been able to find the exact cause of either, the behavior of #34749 makes me suspect that the problem lies in CUDA itself rather than in PyTorch.

Considering that we already have at::ComplexHalf, I don't think having a c10::complex is a bad idea. By creating our own implementation of complex in c10, we will have more control over it, and hopefully it will solve these issues more cleanly.

It is not hard to implement c10::complex: arithmetic, real, and imag are trivial. Functions like exp, log, and sin can simply be implemented by casting to std::complex or thrust::complex and calling the library function.
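A minimal sketch of such a type, assuming C++14. The namespace sketch and the SKETCH_HOST_DEVICE macro are hypothetical illustrations, not the eventual c10 code. Arithmetic is written out by hand so it can be annotated for device use, while exp casts to std::complex and calls std::exp on the host.

```cpp
#include <complex>

// Under nvcc, functions must be annotated to be callable from device code.
// This local macro is a hypothetical stand-in for such an annotation.
#ifdef __CUDACC__
#define SKETCH_HOST_DEVICE __host__ __device__
#else
#define SKETCH_HOST_DEVICE
#endif

namespace sketch {  // hypothetical namespace, not the real c10 code

template <typename T>
struct complex {
  T real_;
  T imag_;

  // A default constructor, which thrust::complex lacks on CUDA 9 / ROCm.
  constexpr SKETCH_HOST_DEVICE complex() : real_(), imag_() {}
  constexpr SKETCH_HOST_DEVICE complex(T re, T im = T()) : real_(re), imag_(im) {}

  // Conversions to and from std::complex, so host code can reuse the library.
  constexpr complex(const std::complex<T>& z) : real_(z.real()), imag_(z.imag()) {}
  explicit constexpr operator std::complex<T>() const {
    return std::complex<T>(real_, imag_);
  }

  constexpr SKETCH_HOST_DEVICE T real() const { return real_; }
  constexpr SKETCH_HOST_DEVICE T imag() const { return imag_; }
};

// Arithmetic is spelled out directly, so it behaves the same on host and device.
template <typename T>
constexpr SKETCH_HOST_DEVICE complex<T> operator+(complex<T> a, complex<T> b) {
  return complex<T>(a.real() + b.real(), a.imag() + b.imag());
}

template <typename T>
constexpr SKETCH_HOST_DEVICE complex<T> operator*(complex<T> a, complex<T> b) {
  return complex<T>(a.real() * b.real() - a.imag() * b.imag(),
                    a.real() * b.imag() + a.imag() * b.real());
}

template <typename T>
constexpr SKETCH_HOST_DEVICE bool operator==(complex<T> a, complex<T> b) {
  return a.real() == b.real() && a.imag() == b.imag();
}

// Transcendental functions delegate to the standard library (host only here).
// A device build would instead convert to thrust::complex and call its exp.
template <typename T>
complex<T> exp(complex<T> z) {
  return complex<T>(std::exp(static_cast<std::complex<T>>(z)));
}

}  // namespace sketch
```

log and sin would follow the same pattern as exp. Keeping the type a plain aggregate of two T values also keeps its layout compatible with std::complex<T> and thrust::complex<T>, which is what makes the cast-and-call approach cheap.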

Once c10::complex exists, every std::complex in c10/core/ScalarType.h will be replaced with c10::complex, so CPU and CUDA share the same scalar_t for each complex dtype.

The only problem I can think of is backward compatibility. If a user has written something like

#include <complex>
#include <type_traits>

template<typename T>
struct is_complex: public std::false_type {};
template<typename T>
struct is_complex<std::complex<T>> : public std::true_type {};

it will stop recognizing the complex scalar types, because c10::complex<T> is a different type from std::complex<T>.
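One way a user could migrate such a trait is to add a specialization for the new type. The sketch below uses a hypothetical c10_sketch::complex stand-in rather than the real c10 type, and keeps the original is_complex unchanged to show that it rejects the new type; is_any_complex is an illustrative name.

```cpp
#include <complex>
#include <type_traits>

// Hypothetical stand-in for the proposed c10::complex.
namespace c10_sketch {
template <typename T>
struct complex {
  T real_, imag_;
};
}  // namespace c10_sketch

// The user-side trait from the snippet above: it only knows std::complex.
template <typename T>
struct is_complex : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : public std::true_type {};

// Migrated trait: defer to the old one, and also accept the new type.
template <typename T>
struct is_any_complex : public is_complex<T> {};
template <typename T>
struct is_any_complex<c10_sketch::complex<T>> : public std::true_type {};

// The old trait silently returns false for the new type.
static_assert(!is_complex<c10_sketch::complex<float>>::value, "");
static_assert(is_any_complex<c10_sketch::complex<float>>::value, "");
```

Because the old trait compiles fine and just returns false, this kind of breakage would surface as wrong dispatch rather than as a build error.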

cc @ezyang @anjali411 @dylanbespalko

Labels: module: complex (related to complex number support in PyTorch), triaged
