
Conversation

@karol-arndt commented Jul 2, 2018

I was recently working on some Kalman filter stuff and found myself in need of a batch matrix inverse, so I implemented it (doing it in a Python for loop is incredibly slow, especially with CUDA). Since cuBLAS has such functionality already implemented (and the standard inverse function just passes 1 as the batch size), it's just a matter of allocating some buffers and passing data to the appropriate cuBLAS functions. The implementation is based on the btrifact (batch LU factorization using getrf) function. I also added a CPU implementation (which is really just a for loop) for the sake of completeness.
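For context, a minimal sketch (illustrative only, not code from this PR) of the slow per-matrix Python loop that a batched kernel replaces:

```python
import torch

def batch_inverse_loop(x):
    # Invert each (n, n) slice of a (b, n, n) tensor one at a time.
    # On CUDA this launches a separate kernel per matrix, which is
    # why the loop is so slow compared to a single batched call.
    out = torch.empty_like(x)
    for i in range(x.shape[0]):
        out[i] = torch.inverse(x[i])
    return out
```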

I figured that this might be useful for other people, so I'm sharing it here. This is my first contribution to PyTorch and I'm not very experienced with CUDA programming, so all comments regarding the code are most welcome.

@vadimkantorov (Contributor)

A general getrf wrapper may also be useful, e.g. for porting Randomized PCA functions (#8049)...

@karol-arndt (Author)

@vadimkantorov isn't that what btrifact does? Looking at the code, it only wraps the call to getrf with some code to ensure column-major layout, manage the device buffers, and check the error codes.

@vadimkantorov (Contributor)

@karol-arndt You are right! Thanks for the tip :) This makes porting that fbpca code to PyTorch trivial.

@fmassa (Member) commented Jul 2, 2018

Thanks for the PR!
I have a quick suggestion on the Python interface: instead of adding a new binverse function, couldn't we extend inverse to support arbitrary batch dimensions (view all the batch dimensions as one, call the batched kernel, then view the result back to the original shape)? What do you think?
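A rough sketch of that idea (the names `inverse_any_batch_dims` and `batched_inverse` are placeholders, not this PR's code):

```python
import torch

def inverse_any_batch_dims(x, batched_inverse):
    # Collapse all leading batch dimensions into a single one,
    # run the (b, n, n) batched kernel, then restore the shape.
    # `batched_inverse` stands in for the actual batched op.
    n = x.shape[-1]
    flat = x.reshape(-1, n, n)
    return batched_inverse(flat).reshape(x.shape)
```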

@karol-arndt (Author)

Well, this indeed sounds like a natural and intuitive extension of the current inverse method, and it wouldn't add any clutter to the Python interface. It's also how the API of another popular tensor library works.
On the other hand, most (if not all) current functions that operate on batched data have names prefixed with b (bmm, btrifact, btrisolve, ...), and extending inverse to work with both batches and single n-by-n matrices would break that convention.
Overall, it's hard to say which is preferable; perhaps some feedback from the code owners could help settle this.

@fmassa (Member) commented Jul 3, 2018

I think the trend will probably be to move away from the b prefix (which is in some sense legacy) in favor of functions that handle batches themselves, the way matmul succeeds mm and bmm.
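For reference, matmul already broadcasts over a leading batch dimension, so for 3-D inputs it covers what bmm does (shapes below are arbitrary examples):

```python
import torch

a = torch.randn(5, 2, 3)
b = torch.randn(5, 3, 4)
# matmul treats the leading dimension as a batch, matching bmm
# for explicitly batched 3-D inputs.
print(torch.matmul(a, b).shape)  # torch.Size([5, 2, 4])
print(torch.bmm(a, b).shape)     # same shape, same values
```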

But let's see what the others think about it.

@zou3519 (Contributor) commented Jul 3, 2018

I agree with @fmassa -- I think we're trying to move away from the 'b' prefix to simplify the API, and I would prefer that batch inverse be built into the existing inverse function.

bmm in particular is "deprecated" in favor of torch.matmul, and btrifact and btrisolve are named that way because they're direct bindings to the LAPACK functions of the same name.

@ezyang (Contributor) commented Jul 3, 2018

How hard would it be to implement this in ATen versus TH? (Asking for information.)

@karol-arndt (Author)

@fmassa In that case, I'll adapt the code to extend the inverse function.

@ezyang I've never written any ATen code before, but I wouldn't expect it to be particularly difficult; the API seems friendlier than TH's. If the current trend is to write new code in ATen, I can try to port the implementation in the coming days. Should I use MKL for the CPU code?

@ezyang (Contributor) commented Jul 4, 2018

Yep, we're trying to do everything in ATen as much as possible. Sometimes it's not possible, but when it is, it's preferred.

If MKL has a good CPU implementation, it's definitely a good pick.

@karol-arndt (Author)

@ezyang I implemented the inverse methods in ATen. I had to add a cuBLAS handle to the ATen context (it's shared with the THC handle, similarly to the cuSPARSE handle).

Worth noting: in addition to the cuBLAS version, THC also had a MAGMA-based implementation of inverse, which I didn't reimplement when porting to ATen. Is it still needed, and should it be reimplemented?

@fmassa As you asked, the code now works with an arbitrary number of batch dimensions, as an extension of the previous inverse method.
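To illustrate the extended interface (the shapes are arbitrary examples):

```python
import torch

# inverse now accepts any number of leading batch dimensions and
# inverts each trailing 4x4 matrix independently.
a = torch.randn(2, 3, 4, 4)
a_inv = torch.inverse(a)
print(a_inv.shape)  # torch.Size([2, 3, 4, 4])
```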

@ezyang (Contributor) commented Jul 9, 2018

@pytorchbot retest this please

scalar_t **input_gpu;   // assumed declaration: needed by the cudaMalloc below
scalar_t **output_gpu;  // device-side array of output matrix pointers
scalar_t **input_ptrs = new scalar_t*[batch_size];   // host-side pointer tables,
scalar_t **output_ptrs = new scalar_t*[batch_size];  // one entry per batch element
AT_CUDA_CHECK(cudaMalloc(&input_gpu, batch_size*sizeof(scalar_t*)));


@vadimkantorov mentioned this pull request Jul 9, 2018
@soumith (Contributor) commented Jul 13, 2018

@karol-arndt the MAGMA-based inverse is MUCH faster than the cuBLAS version for many sizes and has a better overall performance profile (as @martinarjovsky, who has been reporting this, can attest). If it's not too much of an ask, porting the MAGMA bindings would be good as well.

@karol-arndt (Author)

Well, this PR definitely needs more work anyway, as some of the tests are currently failing: the results come back as transposes of the correct ones, which looks like a row- vs column-major layout issue on some platforms. I currently don't have the time to keep working on this, and the current implementation is good enough for me to continue my research. I'll most likely return to it in a few weeks, but if someone has the time and energy to fix the issue and add the MAGMA version in the meantime, that would be great ;)
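In case it helps whoever picks this up, a small illustration (not code from this PR) of why a layout mix-up produces exactly transposed output: cuBLAS assumes column-major storage, so a row-major buffer read unchanged is interpreted as the transpose, and inverting the transpose transposes the inverse.

```python
import torch

# (A^T)^-1 == (A^-1)^T, so inverting a matrix whose storage is
# (mis)read as its transpose returns the transposed inverse.
a = torch.tensor([[1., 2.], [3., 4.]])
lhs = torch.inverse(a.t()).t()
rhs = torch.inverse(a)
print(torch.allclose(lhs, rhs))  # True
```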

@ailzhang assigned and unassigned themselves Jul 24, 2018
@vishwakftw mentioned this pull request Jul 28, 2018
@weiyangfb (Contributor)

I believe there is continued work on this PR at #9949. Thanks @karol-arndt for implementing this feature and moving it into ATen!

@vishwakftw (Contributor)

I think this can be closed now, since batch inverse is now part of master.

@fmassa (Member) commented Oct 28, 2018

Thanks @karol-arndt for the original implementation!

