Conversation
|
... |
|
I should have read carefully the discussion in #2337 first 😅 Looks like I was duplicating the work... |
I am testing |
Though the result is still incorrect...
| work = _cupy.empty(lwork, dtype=a.dtype) | ||
| info = _cupy.empty(1, dtype=_numpy.int32) | ||
| solver(handle, jobz, m, n, a.data.ptr, lda, s.data.ptr, | ||
| info = _cupy.empty(batch_size, dtype=_numpy.int32) |
There was a problem hiding this comment.
This is a bug: both cuSOLVER and rocSOLVER need the info array to be of the batch size. I thought we fixed it earlier...?
...but runs almost as fast as a pure Python loop...
avoid slicing overhead, check info in the end, etc
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
|
On CUDA: the manual looping follows what's done in JAX: https://github.com/google/jax/blob/8bf3f032989caddf05c702acc1ff5353abbe72d2/jaxlib/cusolver.cc#L937-L947 |
This comment has been minimized.
This comment has been minimized.
|
Jenkins, test this please |
|
@anaruse can you please take a look at this? 😇 |
|
Jenkins CI test (for commit 3515fb9, target branch master) succeeded! |
| A += m * n; | ||
| S += k; | ||
| U += m * m; | ||
| VT += n * n; |
There was a problem hiding this comment.
Is it assumed that (lda, ldu, ldvt) == (m, m, n)?
Keeping the signature of gesvd is nice, but I'd like to see the wrapper is trivially correct.
There was a problem hiding this comment.
Yes, when the wrapper is called this is ensured. How about removing lda etc from the signature and just using plain m, n?
There was a problem hiding this comment.
How about removing
ldaetc from the signature and just using plainm,n?
Done in ce32d03.
| a_gpu_usv = cupy.matmul(u_gpu[..., :k] * s_gpu[..., None, :], | ||
| vh_gpu[..., :k, :]) | ||
| else: | ||
| a_gpu_usv = cupy.matmul(u_gpu*s_gpu[..., None, :], vh_gpu) |
There was a problem hiding this comment.
I'm fine with calling matmul regardless of shape.
| # copy (via possible type casting) is done in _gesvd_batched | ||
| out = _gesvd_batched(a, a_dtype, full_matrices, compute_uv, False) | ||
| if compute_uv: | ||
| u, s, v = out |
There was a problem hiding this comment.
When I read the code first, I missed _gesvd_batched returns v instead of vt. Adding a comment to the docstring of _gesvd_batched seems sufficient for now.
|
Check again Jenkins, test this please |
|
Jenkins CI test (for commit 3515fb9, target branch master) succeeded! |
Co-authored-by: Toshiki Kataoka <[email protected]>
|
Jenkins, test this please |
|
Jenkins CI test (for commit 5a0376a, target branch master) succeeded! |
|
Jenkins, test this please |
|
Jenkins CI test (for commit 5a0376a, target branch master) failed with status FAILURE. |
|
The failed test on Jenkins is known to be flaky (#4673). |
|
Thanks, @toslunar! |
Close #3470. Part of #3062. Another shot based on #3247
to avoid manual looping.UPDATE: See the below discussions in this PR.
Prepare for adopting Array API (data-apis/array-api#114).
Code path divergence:
TODO: