Add cuSOLVER path for torch.geqrf #56252
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediations

As of commit 5011c45 (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
ghstack-source-id: 5cbd242 Pull Request resolved: pytorch#56252
Ref. #47953 [ghstack-poisoned]
ghstack-source-id: 28c2115 Pull Request resolved: pytorch#56252
xwang233 left a comment
Thanks for the PR! Overall this looks good. I have left some comments.
```cpp
params,
m,
n,
CUDA_R_32F,
```
Would it be beneficial to get rid of the template specialization for the 64-bit API with something like this?
pytorch/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu, lines 492 to 506 in bcdcf34:

```cpp
#ifdef USE_CUSOLVER_64_BIT
  cusolverDnParams_t params;
  cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
  TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));
  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::xpotrs(
        handle, params, uplo, n, nrhs, datatype,
        A_ptr + i * A_matrix_stride,
        lda, datatype,
        self_working_copy_ptr + i * self_matrix_stride,
        ldb,
        infos_ptr
    );
  }
```
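The idea behind this suggestion, one generic code path that selects the element type via a runtime `cudaDataType` tag instead of compiling a template specialization per scalar type, can be sketched with a plain lookup table. This is a hypothetical Python analogue only; the real mapping lives in `at::cuda::solver::get_cusolver_datatype` and the real generic entry points are the 64-bit `cusolverDnX*` routines.

```python
# Hypothetical analogue of keying one generic 64-bit cuSOLVER call on a
# runtime datatype tag instead of per-type template specializations.
# The tag names mirror CUDA's cudaDataType enum; this code is a sketch,
# not the actual ATen/cuSOLVER interface.
import numpy as np

CUSOLVER_DATATYPE = {
    np.float32: "CUDA_R_32F",
    np.float64: "CUDA_R_64F",
    np.complex64: "CUDA_C_32F",
    np.complex128: "CUDA_C_64F",
}

def xgeqrf_dispatch(a):
    """Single generic entry point; the backend learns the element type
    from a datatype tag rather than from a template parameter."""
    datatype = CUSOLVER_DATATYPE[a.dtype.type]
    # A real implementation would now pass `datatype` to the one generic
    # cuSOLVER routine; here we just report which tag was selected.
    return datatype

assert xgeqrf_dispatch(np.zeros((2, 2), dtype=np.float32)) == "CUDA_R_32F"
assert xgeqrf_dispatch(np.zeros((2, 2), dtype=np.complex128)) == "CUDA_C_64F"
```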
ghstack-source-id: 6da75be Pull Request resolved: pytorch#56252
ghstack-source-id: ebe22ac Pull Request resolved: pytorch#56252
ghstack-source-id: 01ccae0 Pull Request resolved: pytorch#56252
mruberry left a comment
Thanks for taking a look, @xwang233!
Differential Revision: [D27960152](https://our.internmc.facebook.com/intern/diff/D27960152) [ghstack-poisoned]
ghstack-source-id: 9d9edd8 Pull Request resolved: pytorch#56252
```cpp
void geqrf_kernel(const Tensor& input, const Tensor& tau, int64_t m, int64_t n) {
#if defined(USE_CUSOLVER)
  return geqrf_cusolver(input, tau, m, n);
#else
  return geqrf_magma(input, tau, m, n);
#endif
}
```
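For readers unfamiliar with what `geqrf` computes, here is a minimal NumPy sketch of the LAPACK-style contract that both backends implement: R is left in the upper triangle of the output, the Householder reflectors are packed below the diagonal, and their scalar factors are returned in `tau`. This is an unblocked illustrative implementation, not the cuSOLVER or MAGMA code.

```python
# Minimal sketch of geqrf's packed output format (NumPy stand-in).
import numpy as np

def geqrf_like(a):
    """Unblocked Householder QR with LAPACK-style packed output."""
    a = np.array(a, dtype=float)
    m, n = a.shape
    k = min(m, n)
    tau = np.zeros(k)
    for j in range(k):
        x = a[j:, j].copy()
        normx = np.linalg.norm(x)
        if normx == 0.0:
            continue
        sign = 1.0 if x[0] >= 0 else -1.0
        beta = -sign * normx
        u = x
        u[0] -= beta                      # u = x - beta * e1, u[0] != 0
        tau[j] = 2.0 * u[0] ** 2 / (u @ u)
        v = u / u[0]                      # LAPACK stores v with v[0] == 1
        # Apply H = I - tau * v v^T to the trailing submatrix.
        a[j:, j:] -= tau[j] * np.outer(v, v @ a[j:, j:])
        a[j + 1:, j] = v[1:]              # pack the reflector below the diagonal
        a[j, j] = beta
    return a, tau

def orgqr_like(a, tau):
    """Reconstruct Q from the packed reflectors (orgqr-style)."""
    m = a.shape[0]
    q = np.eye(m)
    for j in reversed(range(len(tau))):
        v = np.zeros(m)
        v[j] = 1.0
        v[j + 1:] = a[j + 1:, j]
        q -= tau[j] * np.outer(v, v @ q)  # q = H_j @ q
    return q

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))
packed, tau = geqrf_like(A)
R = np.triu(packed)       # torch.linalg.qr(..., mode='r') is geqrf + triu_
Q = orgqr_like(packed, tau)
assert np.allclose(Q @ R, A)
```

The last two lines show why `mode='r'` of `torch.linalg.qr` can be implemented as `geqrf` followed by `triu_`, as discussed below.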
Hi @IvanYashchuk, I forgot to ask for a benchmark table comparing cuSOLVER vs MAGMA. I see that matrices of all shapes are dispatched to the cuSOLVER path in this heuristic. Are there any performance complaints about cusolverDnXgeqrf and cusolverDn&lt;T&gt;geqrf?
torch.linalg.qr with mode='r' basically just calls geqrf + triu_. Here are the results for that: #56256 (comment). They show that for large sizes MAGMA is a bit faster.
Here are the results comparing geqrf_cusolver and geqrf_magma directly. For large sizes the MAGMA variant is faster, but we still use cuSOLVER here unconditionally, since we aim to remove all uses of single-input MAGMA functions: they create and destroy CUDA streams internally.
| Shape                         | cuSOLVER | MAGMA  |
|-------------------------------|----------|--------|
| torch.Size([2, 2]) | 0.049 | 5.3 |
| torch.Size([2, 2, 2]) | 0.034 | 10.2 |
| torch.Size([32, 2, 2]) | 0.417 | 189.5 |
| torch.Size([64, 2, 2]) | 0.840 | 321.8 |
| torch.Size([128, 2, 2]) | 1.6 | 632.9 |
| torch.Size([8, 8]) | 0.062 | 6.1 |
| torch.Size([2, 8, 8]) | 0.122 | 12.4 |
| torch.Size([32, 8, 8]) | 1.8 | 157.6 |
| torch.Size([64, 8, 8]) | 3.7 | 319.0 |
| torch.Size([128, 8, 8]) | 7.5 | 724.8 |
| torch.Size([16, 16]) | 0.125 | 6.7 |
| torch.Size([2, 16, 16]) | 0.247 | 12.7 |
| torch.Size([32, 16, 16]) | 3.9 | 152.8 |
| torch.Size([64, 16, 16]) | 7.8 | 312.1 |
| torch.Size([128, 16, 16]) | 15.6 | 661.9 |
| torch.Size([32, 32]) | 0.256 | 5.7 |
| torch.Size([2, 32, 32]) | 5.1 | 10.1 |
| torch.Size([32, 32, 32]) | 8.1 | 250.8 |
| torch.Size([64, 32, 32]) | 16.2 | 376.7 |
| torch.Size([128, 32, 32]) | 32.5 | 682.1 |
| torch.Size([64, 64]) | 0.658 | 5.7 |
| torch.Size([2, 64, 64]) | 1.3 | 9.3 |
| torch.Size([32, 64, 64]) | 20.9 | 211.8 |
| torch.Size([64, 64, 64]) | 41.8 | 312.9 |
| torch.Size([128, 64, 64]) | 83.7 | 556.3 |
| torch.Size([128, 128]) | 1.5 | 5.2 |
| torch.Size([2, 128, 128]) | 3.1 | 11.6 |
| torch.Size([32, 128, 128]) | 49.8 | 208.4 |
| torch.Size([64, 128, 128]) | 99.8 | 361.6 |
| torch.Size([128, 128, 128]) | 199.6 | 903.5 |
| torch.Size([256, 256]) | 2.3 | 9.7 |
| torch.Size([2, 256, 256]) | 4.6 | 14.7 |
| torch.Size([32, 256, 256]) | 75.9 | 228.9 |
| torch.Size([64, 256, 256]) | 152.0 | 419.8 |
| torch.Size([128, 256, 256]) | 303.9 | 846.4 |
| torch.Size([512, 512]) | 5.8 | 9.8 |
| torch.Size([2, 512, 512]) | 11.727 | 17.9 |
| torch.Size([32, 512, 512]) | 187.4 | 285.0 |
| torch.Size([64, 512, 512]) | 374.8 | 594.5 |
| torch.Size([128, 512, 512]) | 749.3 | 1263.3 |
| torch.Size([1024, 1024]) | 15.3 | 16.3 |
| torch.Size([2, 1024, 1024]) | 30.6 | 32.7 |
| torch.Size([32, 1024, 1024]) | 490.8 | 527.6 |
| torch.Size([64, 1024, 1024]) | 985.4 | 1022.6 |
| torch.Size([128, 1024, 1024]) | 1978.6 | 2026.9 |

| Shape                         | cuSOLVER | MAGMA  |
|-------------------------------|----------|--------|
| torch.Size([512, 512]) | 8.0 | 11.9 |
| torch.Size([1024, 1024]) | 15.1 | 22.5 |
| torch.Size([2048, 2048]) | 54.9 | 54.9 |
| torch.Size([4096, 4096]) | 276.4 | 265.8 |
| torch.Size([8192, 8192]) | 1712.3 | 1555.8 |
Times are in milliseconds (ms).
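A rough sketch of how per-shape timings like those above can be collected. NumPy's QR stands in for the GPU backends here, and the repeat policy is an assumption about the methodology, not a description of it; on CUDA, the timed region must additionally be bracketed with `torch.cuda.synchronize()` because kernel launches are asynchronous.

```python
# Illustrative timing harness for per-shape linear-algebra benchmarks.
# np.linalg.qr(mode="r") is a CPU stand-in for the geqrf-based paths;
# a GPU version would synchronize the device before reading the clock.
import time
import numpy as np

def time_geqrf(shape, repeats=5):
    rng = np.random.default_rng(0)
    a = rng.standard_normal(shape)
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        np.linalg.qr(a, mode="r")        # stand-in for the geqrf call
        best = min(best, time.perf_counter() - t0)
    return best * 1e3                    # milliseconds, as in the table

for shape in [(64, 64), (256, 256)]:
    print(f"{shape}: {time_geqrf(shape):.3f} ms")
```

Taking the best of several repeats (rather than the mean) is a common way to reduce noise from other processes and from first-call warm-up effects.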
Summary: Pull Request resolved: pytorch#56252
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D27960152
Pulled By: mruberry
fbshipit-source-id: 0510a302aab50623d7490efaba0133f740cd57c3
Stack from ghstack:
Differential Revision: D27960152