Add cuSOLVER path for torch.linalg.lstsq #57317
This PR implements a QR-based least squares solver using the geqrf, ormqr, and triangular_solve operations. The internal code of triangular_solve was fixed to correctly handle larger rectangular arrays.
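As a rough illustration of the approach described above, the sketch below solves a full-rank least-squares problem with `m >= n` using the Python-level counterparts of these operations. The helper name `lstsq_qr` is hypothetical, and `torch.linalg.solve_triangular` stands in for `triangular_solve`; this is a sketch under those assumptions, not the C++ code added by the PR.

```python
import torch


def lstsq_qr(A, B):
    """Least-squares solution of A @ X = B for a full-rank A with m >= n.

    Hypothetical helper for illustration only; not the code added by the PR.
    """
    n = A.shape[-1]
    # geqrf: Householder QR. R sits in the upper triangle of `a`; the
    # reflectors that define Q are stored below the diagonal and in `tau`.
    a, tau = torch.geqrf(A)
    # ormqr: compute Q^H @ B without ever forming Q explicitly.
    QhB = torch.ormqr(a, tau, B, left=True, transpose=True)
    # Solve the n x n upper-triangular system R @ X = (Q^H @ B)[:n].
    R = a[..., :n, :].triu()
    return torch.linalg.solve_triangular(R, QhB[..., :n, :], upper=True)


torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
B = torch.randn(5, 2, dtype=torch.float64)
X = lstsq_qr(A, B)
```

For a full-rank tall `A`, the result should agree with `torch.linalg.pinv(A) @ B` up to rounding.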
@xwang233 and @lezcano and/or @nikitaved, would you review this, please?
lezcano
left a comment
LGTM! The logic is as clean as it can be. I just left a small comment on a bit that I found slightly more difficult to understand.
  :attr:`driver` chooses the LAPACK/MAGMA function that will be used.
  For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd'`, `'gelss'`.
- For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank and `m < n`.
+ For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank.
Thanks!
const_cast<Tensor&>(infos),
upper, transpose, conjugate_transpose, unitriangular);

B.narrow(-2, m, n - m).zero_();
This is because triangular_solve_kernel writes its output into the first m elements of B, right? Could you leave a comment explaining this here?
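To make the question concrete, here is a minimal sketch of the `m < n` case in which a triangular solve fills only the first `m` rows of an `n`-row buffer, so the remaining rows have to be zeroed before Q is applied. The helper `lstsq_qr_min_norm` and the NaN-filled buffer are assumptions made for illustration; the PR's kernel works in place on `B` and may differ in detail.

```python
import torch


def lstsq_qr_min_norm(A, B):
    """Minimum-norm solution of A @ X = B for a full-rank 2-D A with m < n.

    Hypothetical helper for illustration only; not the PR's kernel code.
    """
    m, n = A.shape
    k = B.shape[-1]
    # Factor A^H = Q R, so that A = R^H Q^H with R an m x m upper triangle.
    a, tau = torch.geqrf(A.mH.contiguous())
    R = a[:m, :].triu()
    # Stand-in for an n-row work buffer whose contents are stale (NaN here).
    buf = torch.full((n, k), float("nan"), dtype=B.dtype, device=B.device)
    # The triangular solve R^H Z = B only fills the first m rows.
    buf[:m] = torch.linalg.solve_triangular(R.mH, B, upper=False)
    # The remaining n - m rows must be zero before Q is applied; otherwise
    # the stale values would leak into the result.
    buf.narrow(-2, m, n - m).zero_()
    # X = Q @ [Z; 0]
    return torch.ormqr(a, tau, buf, left=True, transpose=False)


torch.manual_seed(0)
A = torch.randn(3, 5, dtype=torch.float64)
B = torch.randn(3, 2, dtype=torch.float64)
X = lstsq_qr_min_norm(A, B)
```

Removing the `zero_()` call in this sketch would propagate the NaN placeholders into every row of `X`, which is the kind of failure the zeroing guards against.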
xwang233
left a comment
LGTM. Thanks for the PR!
test/test_linalg.py
Outdated
  # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
  m_l_n_sizes = [(m // 2, m) for m in ms]
- matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if device == 'cpu' else [])
+ matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if cusolver_available else [])
maybe use (cusolver_available or device == 'cpu') to test both?
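A possible shape for that suggestion is sketched below. The function name `lstsq_test_sizes` and the definition of `m_ge_n_sizes` are assumptions, since the quoted diff does not show how that list is built; only the combined condition comes from the comment.

```python
def lstsq_test_sizes(ms, device, cusolver_available):
    """Matrix shapes (m, n) to test; hypothetical helper for illustration."""
    # Assumed definition: the quoted diff does not show how this list is built.
    m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
    # Wide shapes (m < n) are only supported on CPU and on the cuSOLVER path.
    m_l_n_sizes = [(m // 2, m) for m in ms]
    wide_supported = device == 'cpu' or cusolver_available
    return m_ge_n_sizes + (m_l_n_sizes if wide_supported else [])
```

With this condition the wide shapes are exercised on CPU regardless of the CUDA setup, and on CUDA only when the cuSOLVER path is available.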
This PR implements QR-based least squares solver using geqrf, ormqr, and triangular_solve operations. Internal code of triangular_solve was fixed to handle correctly larger sized rectangular arrays. ghstack-source-id: e7d2246 Pull Request resolved: pytorch#57317
mruberry
left a comment
Nice work all! Thanks for reviewing, @nikitaved, @xwang233
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Reverting this PR because it broke one of the Windows test jobs: https://app.circleci.com/pipelines/github/pytorch/pytorch/317376/workflows/463399f8-78ef-4894-a9bf-8b666943efc2/jobs/13217419
This pull request has been reverted by 72ebdd6.
This diff was reverted, but the previous commits in the stack were not, I think. Why it was reverted: it broke pytorch_windows_vs2019_py36_cuda10.1_test2 and the tests test_linalg_lstsq_input_checks_cuda_complex128, test_linalg_lstsq_input_checks_cuda_complex64, test_linalg_lstsq_input_checks_cuda_float32, and test_linalg_lstsq_input_checks_cuda_float64. The easiest way to reland the rest of the stack is probably to rebase the uncommitted PRs on nightly with the fix. We can run the updated PR through ci/all to validate that this build is fixed, too.
@mruberry, I fixed the problem with that Windows CUDA 10.1 build. Here is the ci-all PR #57816. The problem was that the condition of cuSOLVER availability was not correct in the test. I think we should consider adding a more robust way to check from Python whether cuSOLVER is used in PyTorch. We use cuSOLVER if CUDA version is >= 10.1.243, but |
Summary: Pull Request resolved: pytorch#57317 This PR implements QR-based least squares solver using geqrf, ormqr, and triangular_solve operations. Internal code of triangular_solve was fixed to handle correctly larger sized rectangular arrays. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D28242069 Pulled By: mruberry fbshipit-source-id: 23979d19ccc7f591afa8df4435d0db847e2d0d97
Thanks @IvanYashchuk, and thanks for the thorough analysis. So users with a CUDA version between 10.1 and 10.1.243 will get the correct behavior (we think), but our test suite will report the behavior as incorrect?
Yes, but our test suite doesn't test the behavior for these versions, so the tests will pass.
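The version question discussed here can be illustrated with a small, pure-Python sketch. The threshold 10.1.243 is taken from the comment above; the helper names and the idea that a test might only see a `major.minor` string such as `"10.1"` are assumptions for illustration, not an existing PyTorch API.

```python
# Threshold quoted in the discussion: cuSOLVER is used for CUDA >= 10.1.243.
CUSOLVER_MIN_CUDA = (10, 1, 243)


def parse_version(text):
    """Turn a dotted version string such as '10.1.243' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def cusolver_expected(cuda_version):
    """Return True or False when the version string settles the question,
    and None when it is too coarse to tell (hypothetical helper)."""
    parsed = parse_version(cuda_version)
    prefix = CUSOLVER_MIN_CUDA[:len(parsed)]
    if parsed > prefix:
        return True
    if parsed < prefix:
        return False
    # Same leading components: decisive only if the patch level was given.
    return True if len(parsed) >= len(CUSOLVER_MIN_CUDA) else None
```

Under these assumptions, a `"10.1"` string is ambiguous, which is why a test condition based only on the major and minor version can disagree with what the build actually does.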
Summary: Pull Request resolved: pytorch#57317 This PR implements QR-based least squares solver using geqrf, ormqr, and triangular_solve operations. Internal code of triangular_solve was fixed to handle correctly larger sized rectangular arrays. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D28312683 Pulled By: mruberry fbshipit-source-id: dc8ae837a5fb0685d85c8733a47d7d25dc46443a
Differential Revision: D28312683