Skip to content

Commit ce32d03

Browse files
committed
update wrapper signature; raise on CUDA 10.0
1 parent 3515fb9 commit ce32d03

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

cupy/cusolver.pyx

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,13 @@ cdef extern from '../cupy_backends/cupy_complex.h':
4242
cdef extern from '../cupy_backends/cupy_lapack.h' nogil:
4343
int gesvd_loop[T](
4444
intptr_t handle, char jobu, char jobvt, int m, int n, intptr_t A,
45-
int lda, intptr_t s_ptr, intptr_t u_ptr, int ldu, intptr_t vt_ptr,
46-
int ldvt, intptr_t w_ptr, int buffersize, intptr_t info_ptr,
45+
intptr_t s_ptr, intptr_t u_ptr, intptr_t vt_ptr,
46+
intptr_t w_ptr, int buffersize, intptr_t info_ptr,
4747
int batch_size)
4848

49-
ctypedef int (*gesvd_ptr)(intptr_t, char, char, int, int, intptr_t, # noqa
50-
int, intptr_t, intptr_t, int, intptr_t,
51-
int, intptr_t, int, intptr_t, int) nogil
49+
ctypedef int(*gesvd_ptr)(intptr_t, char, char, int, int, intptr_t,
50+
intptr_t, intptr_t, intptr_t,
51+
intptr_t, int, intptr_t, int) nogil
5252

5353

5454
_available_cuda_version = {
@@ -272,6 +272,10 @@ cpdef _gesvd_batched(a, a_dtype, full_matrices, compute_uv, overwrite_a):
272272
raise RuntimeError("This function is disabled on HIP as "
273273
"it is not needed")
274274

275+
if runtime.runtimeGetVersion() == 10000:
276+
# see https://github.com/cupy/cupy/pull/4628#issuecomment-780311925
277+
raise RuntimeError("batched gesvd is buggy on CUDA 10.0")
278+
275279
# TODO(leofang): try overlapping using a small stream pool?
276280

277281
cdef ndarray x, s, u, vt, dev_info
@@ -344,8 +348,8 @@ cpdef _gesvd_batched(a, a_dtype, full_matrices, compute_uv, overwrite_a):
344348
# the loop starts here, with gil released to reduce overhead
345349
with nogil:
346350
status = gesvd(
347-
handle, job_u, job_vt, m, n, a_ptr, m, s_ptr,
348-
u_ptr, m, vt_ptr, n,
351+
handle, job_u, job_vt, m, n, a_ptr,
352+
s_ptr, u_ptr, vt_ptr,
349353
w_ptr, buffersize, info_ptr, batch_size)
350354
if status != 0:
351355
raise _cusolver.CUSOLVERError(status)

cupy_backends/cupy_lapack.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,8 @@ template<> struct gesvd_func<cuDoubleComplex, double> { gesvd<cuDoubleComplex, d
3333
template<typename T>
3434
int gesvd_loop(
3535
intptr_t handle, char jobu, char jobvt, int m, int n, intptr_t a_ptr,
36-
int lda, intptr_t s_ptr, intptr_t u_ptr, int ldu, intptr_t vt_ptr,
37-
int ldvt, intptr_t w_ptr, int buffersize, intptr_t info_ptr,
36+
intptr_t s_ptr, intptr_t u_ptr, intptr_t vt_ptr,
37+
intptr_t w_ptr, int buffersize, intptr_t info_ptr,
3838
int batch_size) {
3939
/*
4040
* Assumptions:
@@ -60,8 +60,8 @@ int gesvd_loop(
6060
for (int i=0; i<batch_size; i++) {
6161
// setting rwork to NULL as we don't need it
6262
status = func(
63-
reinterpret_cast<cusolverDnHandle_t>(handle), jobu, jobvt, m, n, A, lda,
64-
S, U, ldu, VT, ldvt, Work, buffersize, NULL, devInfo);
63+
reinterpret_cast<cusolverDnHandle_t>(handle), jobu, jobvt, m, n, A, m,
64+
S, U, m, VT, n, Work, buffersize, NULL, devInfo);
6565
if (status != 0) break;
6666
A += m * n;
6767
S += k;

0 commit comments

Comments
 (0)