Implement batch matrix inverse #9102
Conversation
A general …
@vadimkantorov isn't that what …?
@karol-arndt You are right! Thanks for the tip :) This makes porting that fbpca code to PyTorch trivial.
Thanks for the PR!
Well, this indeed sounds like a natural and intuitive extension of the current `inverse` function.
I think the trend will probably be to move away from the `b*` prefix. But let's see what the others think about it.
I agree with @fmassa -- I think we're trying to move away from the 'b' prefix to simplify the API, and I would prefer that batch inverse be built into the `inverse` function.
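For concreteness, here is a sketch of the API shape being discussed, written against the ATen C++ tensor API. It illustrates the proposed behavior (which, per the end of this thread, eventually landed), and is not code from this PR:

```cpp
// Illustrative only: a single `inverse` entry point that accepts both one
// matrix and a batch of matrices, instead of a separate b-prefixed function.
#include <ATen/ATen.h>

void inverse_api_example() {
  at::Tensor single = at::randn({4, 4});
  at::Tensor batch  = at::randn({8, 4, 4});
  at::Tensor inv_single = single.inverse();  // shape (4, 4)
  at::Tensor inv_batch  = batch.inverse();   // shape (8, 4, 4), one inverse per matrix
}
```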
How hard would it be to implement this in ATen versus TH? (Asking for information.)
@fmassa In that case, I'll adapt the code to extend the existing `inverse` function.

@ezyang I've never written any ATen code before. I wouldn't expect it to be particularly difficult, though; the API seems friendlier than TH's. If the current trend is to write new code in ATen, I can try to port the implementation in the coming days. Should I use MKL for the CPU code?
Yep, we're trying to do everything in ATen as much as possible. Sometimes it's not possible, but when it is, it's preferred. If MKL has a good CPU implementation, it's definitely a good pick.
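For reference, a minimal sketch of what an MKL/LAPACKE-backed CPU path could look like; `inverse_batched_lapacke` is a hypothetical name, and error handling is omitted:

```cpp
// Hypothetical sketch, not the PR's code: invert each matrix of a contiguous
// (batch, n, n) row-major buffer in place, using LAPACKE's getrf/getri
// (provided by MKL, among other LAPACK implementations).
#include <lapacke.h>
#include <vector>
#include <cstddef>

void inverse_batched_lapacke(float* a, int n, int batch_size) {
  std::vector<lapack_int> ipiv(n);  // pivot indices, reused for each matrix
  for (int b = 0; b < batch_size; ++b) {
    float* mat = a + (std::size_t)b * n * n;
    LAPACKE_sgetrf(LAPACK_ROW_MAJOR, n, n, mat, n, ipiv.data());  // LU factorization
    LAPACKE_sgetri(LAPACK_ROW_MAJOR, n, mat, n, ipiv.data());     // inverse from the LU factors
  }
}
```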
@ezyang I implemented the inverse methods in ATen. I had to add a cuBLAS handle to the ATen context (it's shared with the THC handle, similarly to the cuSPARSE handle). Worth noting: in addition to the cuBLAS version, THC also had a MAGMA-based implementation of inverse, which I didn't reimplement when porting to ATen. Is it still needed, and should it be reimplemented?

@fmassa As you asked, the code now works with an arbitrary number of batch dimensions, as an extension of the previous `inverse` function.
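A minimal sketch, assuming nothing about the PR's internals, of how arbitrary leading batch dimensions can be collapsed into the single flat batch a batched kernel expects (the function name is hypothetical):

```cpp
// Illustrative only: reduce an (..., n, n) input to (batch, n, n), run the
// batched kernel on the flat view, then restore the original shape.
#include <ATen/ATen.h>

at::Tensor with_flattened_batch(const at::Tensor& self) {
  auto n = self.size(-1);
  auto flat = self.contiguous().view({-1, n, n});  // merge all leading dims
  // ... invoke the batched inverse kernel on `flat` here ...
  return flat.view(self.sizes());                  // back to (..., n, n)
}
```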
@pytorchbot retest this please
aten/src/ATen/native/cuda/Inverse.cu
```cpp
scalar_t **input_gpu;   // assumed declared just above this excerpt
scalar_t **output_gpu;
// Host-side arrays of per-matrix device pointers, one entry per batch element.
scalar_t **input_ptrs = new scalar_t*[batch_size];
scalar_t **output_ptrs = new scalar_t*[batch_size];
// Device-side array that will hold copies of those pointers for cuBLAS.
AT_CUDA_CHECK(cudaMalloc(&input_gpu, batch_size*sizeof(scalar_t*)));
```
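The hunk above is assembling host-side arrays of per-matrix device pointers and copying them to the device, which is the form cuBLAS's batched API consumes. A rough, self-contained sketch of that overall pattern for `float` (my illustration, with error handling omitted; not the PR's code):

```cpp
// Sketch of batched inversion via cuBLAS: LU-factorize all matrices with
// getrfBatched, then invert from the factors with getriBatched.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <vector>

void batched_inverse(cublasHandle_t handle, float* A, float* Ainv,
                     int n, int batch_size) {
  // Host-side arrays of pointers to each (n, n) matrix on the device.
  std::vector<float*> a_ptrs(batch_size), c_ptrs(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    a_ptrs[i] = A + (size_t)i * n * n;
    c_ptrs[i] = Ainv + (size_t)i * n * n;
  }
  // cuBLAS wants those pointer arrays in device memory.
  float **d_a_ptrs, **d_c_ptrs;
  cudaMalloc(&d_a_ptrs, batch_size * sizeof(float*));
  cudaMalloc(&d_c_ptrs, batch_size * sizeof(float*));
  cudaMemcpy(d_a_ptrs, a_ptrs.data(), batch_size * sizeof(float*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_c_ptrs, c_ptrs.data(), batch_size * sizeof(float*),
             cudaMemcpyHostToDevice);

  // Pivot indices (n per matrix) and per-matrix status codes.
  int *d_pivots, *d_info;
  cudaMalloc(&d_pivots, (size_t)batch_size * n * sizeof(int));
  cudaMalloc(&d_info, batch_size * sizeof(int));

  // One call factorizes the whole batch; a second inverts from the factors.
  cublasSgetrfBatched(handle, n, d_a_ptrs, n, d_pivots, d_info, batch_size);
  cublasSgetriBatched(handle, n, (const float* const*)d_a_ptrs, n, d_pivots,
                      d_c_ptrs, n, d_info, batch_size);

  cudaFree(d_a_ptrs); cudaFree(d_c_ptrs);
  cudaFree(d_pivots); cudaFree(d_info);
}
```

Note that getriBatched works out of place: the factored input and the inverse output must be distinct buffers, which is why two pointer arrays are needed.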
@karol-arndt the MAGMA-based inverse is MUCH faster than the cuBLAS version for many sizes, and it has a better performance profile overall (as @martinarjovsky, who has been reporting on this, can attest). If it's not too much of an ask, porting the MAGMA bindings would be good as well.
Well, this PR definitely needs more work anyway, as some of the tests are currently failing (the results appear to be transposes of the correct ones, which suggests a data layout issue, such as row- versus column-major ordering, on some platforms). I currently don't have the time to work on this anymore, and the current implementation is good enough for me to continue my research. I will most likely return to this in a few weeks, but if someone has the time and energy to fix the issue and add the MAGMA version, that would be great ;)
I believe work on this PR is being continued at #9949. Thanks @karol-arndt for implementing this feature and moving it into ATen!
I think this can be closed now, since batch inverse is now part of master.
Thanks @karol-arndt for the original implementation!
I was recently working on some Kalman filter code and found myself in need of a batch matrix inverse, so I implemented it (doing it in a Python for loop is incredibly slow, especially with CUDA). Since cuBLAS already provides this functionality (and the standard inverse function just passes 1 as the batch size), it's just a matter of allocating some buffers and passing the data to the appropriate cuBLAS functions. The implementation is based on the `btrifact` (batch LU factorization using `getrf`) function. I also added a CPU implementation (which is really just a for loop) for the sake of completeness.

I figured that this might be useful for other people, so I'm sharing it here. This is my first contribution to PyTorch and I'm not very experienced with CUDA programming, so all comments regarding the code are most welcome.
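As a closing illustration, a quick correctness check in the spirit of what a test for this feature might do (a hypothetical helper, not from the PR), comparing the batched result against per-matrix inverses:

```cpp
// Illustrative check: the batched inverse should match inverting each matrix
// individually. The identity shift keeps the random matrices well-conditioned.
#include <ATen/ATen.h>
#include <cassert>

void check_batch_inverse() {
  at::Tensor a = at::randn({8, 5, 5}) + 5 * at::eye(5);
  at::Tensor batched = a.inverse();  // batched path over the leading dim
  for (int64_t i = 0; i < a.size(0); ++i) {
    assert(at::allclose(batched[i], a[i].inverse(), 1e-4, 1e-4));
  }
}
```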