Add torch.dot for complex tensors #42745
TODO: potentially add a fast path for complex dot
aten/src/ATen/native/BlasKernel.cpp
#if AT_BUILD_WITH_BLAS()
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
"...you should not expect to use Fortran functions that return types such as COMPLEX or COMPLEX*16. Write a SUBROUTINE interface to your Fortran function instead, and then invoke it as a void function from C or C++."
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
Hmm, I wonder if we shouldn't have some methods on c10::complex for doing pointery conversions like this. It would be nice to not have to be slinging reinterpret cast everywhere. (No action needed for PR)
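A minimal sketch of what such a conversion helper could look like. This is an illustration only: `as_interleaved` is a hypothetical name, `std::complex<double>` stands in for `c10::complex<double>`, and plain `double*` stands in for the cuBLAS pointer type, since the CUDA headers are not assumed here. The cast itself relies on the standard's guarantee that an array of `std::complex<T>` can be accessed as interleaved real and imaginary `T` values.

```cpp
#include <complex>

// Hypothetical helper, not part of c10 or this PR.
// Views an array of std::complex<double> as interleaved doubles:
// element i has its real part at index 2*i and its imaginary part at 2*i + 1.
inline const double* as_interleaved(const std::complex<double>* p) {
  return reinterpret_cast<const double*>(p);
}

// Mutable overload of the same hypothetical helper.
inline double* as_interleaved(std::complex<double>* p) {
  return reinterpret_cast<double*>(p);
}
```

With a helper like this, the `reinterpret_cast` lives in one place and call sites would pass `as_interleaved(x)` instead of spelling out the cast each time.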
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382)
@anjali411 as discussed offline, the reason CPU results for the ROCm build are failing is addressed by this diff:
diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp
index ef05cb8..1fe8a73 100644
--- a/aten/src/ATen/native/BlasKernel.cpp
+++ b/aten/src/ATen/native/BlasKernel.cpp
@@ -17,9 +17,6 @@ extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int
# define ffloat float
#endif
-extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
-extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
-extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#ifdef BLAS_USE_CBLAS_DOT
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
@@ -40,6 +37,10 @@ static inline void zdotu_(std::complex<double> *res, const int *n, const std::co
cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
}
#endif // THBlas_cblas_dot_
+#else // BLAS_USE_CBLAS_DOT
+extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
+extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
+extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#endif // BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS
BLAS development began in Fortran, and the function calling convention differs between compilers and compiler flags, which makes it difficult to call a Fortran function correctly from C. Sometimes the complex value is returned through a hidden positional argument, as in your extern declaration, and sometimes it is returned like a regular C function return value. Since the OpenBLAS symbols were getting linked and their extern declaration was incorrect, the function was called, but its declared signature did not match the one the library expected.
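To make the "result through an out-pointer" style concrete, here is a small self-contained sketch. It does not link against any BLAS: `ref_zdotu_sub` is a hypothetical reference implementation that mirrors the argument order of `cblas_zdotu_sub` from the diff above (typed pointers are used instead of `void*` for readability), and `ref_dot` is a hypothetical by-value wrapper. It assumes positive increments only.

```cpp
#include <cassert>
#include <complex>
#include <vector>

// Hypothetical stand-in for a zdotu routine with a "subroutine" interface.
// The complex result is written through `result` rather than returned by
// value, so the caller does not depend on how a given Fortran compiler
// returns COMPLEX*16 values. Assumes incx > 0 and incy > 0.
void ref_zdotu_sub(int n, const std::complex<double>* x, int incx,
                   const std::complex<double>* y, int incy,
                   std::complex<double>* result) {
  std::complex<double> acc(0.0, 0.0);
  for (int i = 0; i < n; ++i) {
    acc += x[i * incx] * y[i * incy];  // unconjugated product, as in zdotu
  }
  *result = acc;
}

// Hypothetical by-value wrapper around the out-pointer interface.
std::complex<double> ref_dot(const std::vector<std::complex<double>>& x,
                             const std::vector<std::complex<double>>& y) {
  assert(x.size() == y.size());
  std::complex<double> out;
  ref_zdotu_sub(static_cast<int>(x.size()), x.data(), 1, y.data(), 1, &out);
  return out;
}
```

For example, `ref_dot` on `{1+2i, 3+4i}` and `{5+6i, 7+8i}` gives `-18+68i`, since `(1+2i)(5+6i) = -7+16i` and `(3+4i)(7+8i) = -11+52i`.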
@anjali411 merged this pull request in aab6660.
TH_EXTERNC void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
TH_EXTERNC void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
#ifndef THBlas_cblas_dot_
why are these symbols being defined?
removed in #43148