Add batched linear solver to torch.gesv() #6100
Conversation
@ssnl could I get another review on this please? I've addressed all the previous comments. IIRC, your last pending comment was that we should delete the TH gesv and move it to ATen as well (because the batched gesv does the TH gesv in a loop). Could we leave that for a future PR? (I'll open another issue for it.) It's much more difficult to implement, and I don't think it's worth the marginal benefit of deduplicating our gesv code right now, especially since users have been requesting batched gesv for a while.
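The "TH gesv in a loop" approach mentioned above can be illustrated with a minimal NumPy sketch (NumPy stands in for the TH/ATen internals here; `looped_solve` is a hypothetical helper, not code from this PR). Solving each system in the batch one at a time with a single-matrix solver produces the same result as a batched solve:

```python
import numpy as np

def looped_solve(A, B):
    # A: (batch, m, m), B: (batch, m, k).
    # Solve each system independently, as the CPU path does in a for loop.
    return np.stack([np.linalg.solve(a, b) for a, b in zip(A, B)])

rng = np.random.default_rng(0)
# Adding 3*I keeps the matrices well-conditioned for this sketch.
A = rng.standard_normal((4, 3, 3)) + 3 * np.eye(3)
B = rng.standard_normal((4, 3, 2))

X_loop = looped_solve(A, B)
X_batched = np.linalg.solve(A, B)  # NumPy's built-in batched solve
assert np.allclose(X_loop, X_batched)
```

The loop is simple but pays per-call dispatch overhead, which is why a true batched kernel (MAGMA's `gesv_batched` on CUDA) is preferable when available.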
ssnl left a comment:
Two minor comments
Fixes pytorch#3164. Picks up from pytorch#4502.

I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop.

The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`.

The overhead of creating the `magma_queue_t` is ~350000 microseconds the first time it's called and ~6 microseconds every time after that.
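The batch semantics described above can be sketched with NumPy's batched `np.linalg.solve` (a stand-in for the behavior this PR adds to `torch.gesv`; the shapes below are illustrative, not from the PR's test suite). A 4-d `A x B x M x M` operand behaves like a flat batch of `A*B` independent `M x M` systems:

```python
import numpy as np

rng = np.random.default_rng(0)
# Batch dims (2, 5): conceptually 2*5 = 10 independent 3x3 systems.
# Adding 3*I keeps the matrices well-conditioned for this sketch.
A = rng.standard_normal((2, 5, 3, 3)) + 3 * np.eye(3)
b = rng.standard_normal((2, 5, 3, 1))

x = np.linalg.solve(A, b)  # solves all 10 systems in one call
assert x.shape == (2, 5, 3, 1)

# Equivalent to flattening the batch dims into a single batch of size 10:
x_flat = np.linalg.solve(A.reshape(10, 3, 3), b.reshape(10, 3, 1))
assert np.allclose(x, x_flat.reshape(2, 5, 3, 1))
```

The same flattening view explains why arbitrary leading batch dimensions (and broadcasting over them) reduce to the single-batch case the MAGMA binding handles.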
@ssnl done :) sorry for the wait!
@pytorchbot retest this please |
ssnl left a comment:
LGTM
* Add batched linear solver to torch.gesv()
* Tests and docs
* Address comments
* Address comments
* Rebase
* Address comments
* Fix rebase
* Addressed comments
* Address comments
* Address comments
* Addressed comments
Fixes #3164
Picks up from #4502

I moved `gesv` to ATen.
Adds bindings for MAGMA's `gesv_batched` function for CUDA.
For CPU, runs `THLapack(gesv)` in a for loop.
The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`.
The overhead of creating the `magma_queue_t` is ~350000 microseconds the first time it's called and ~6 microseconds every time after that.
cc @colesbury @apaszke