🚀 Feature
Performance improvements to linalg on GPUs
Motivation
After noticing that one of my functions was running unusually slowly and timing it line by line, I found that a batched inversion of 2 x 2 matrices was hundreds of times slower than an 8k-point FFT on GPU. Experiments show that a 2 x 2 inversion formula implemented in Python can be hundreds of times faster on GPU than the linalg implementations of inv and solve in PyTorch 1.9.0.
Pitch
The linalg module is an active development area in PyTorch, so it may have many opportunities for improvement, and it seems I have identified an important one.
Alternatives
Implementing this feature in a wrapper, as I did, is probably not a good idea, since it would likely hurt performance for small batches.
Additional context
I prepared a notebook that can be visualised here, where I describe the implementation with the analytic inversion formula and run the experiments, so that you can review my methodology.
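For reference, a minimal sketch of what such an analytic batched 2 x 2 inverse could look like (this is an illustration of the general approach, not the exact code from the notebook; the function name and batching convention are assumptions):

```python
import torch

def inv_2x2(m: torch.Tensor) -> torch.Tensor:
    """Analytic inverse of a batch of 2x2 matrices, shape (..., 2, 2).

    Uses the closed-form inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / det,
    built entirely from elementwise ops, so it launches a handful of
    kernels regardless of batch size.
    """
    a = m[..., 0, 0]
    b = m[..., 0, 1]
    c = m[..., 1, 0]
    d = m[..., 1, 1]
    det = a * d - b * c
    return torch.stack(
        (
            torch.stack((d, -b), dim=-1),
            torch.stack((-c, a), dim=-1),
        ),
        dim=-2,
    ) / det[..., None, None]
```

Because everything is a fused elementwise/stack operation, there is no per-matrix kernel launch or LAPACK/MAGMA dispatch, which is presumably where the speed-up on GPU comes from.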
Just to give an idea, here I show the performance of a batched solve, in number of 2x2 matrix inversions per microsecond.
| computation | float32 | float64 | complex64 | complex128 |
|---|---|---|---|---|
| custom@cuda | 3533.943528 | 4137.873960 | 3723.701359 | 3645.776368 |
| linalg@cuda | 22.632545 | 22.065084 | 22.591386 | 16.020514 |
| custom@cpu | 117.163244 | 47.921409 | 43.829660 | 12.162328 |
| linalg@cpu | 9.671292 | 7.813403 | 7.009462 | 6.682558 |
The proposed implementation is not very appealing on CPU, but it gives massive gains on the GPU backend. The only scenario where it is considerably worse than the linalg implementation is CPU with backpropagation for complex128; I don't believe this is a very important case, since GPUs are generally available when training.
The proposed method wins by a significant margin, which suggests there may be gains for larger systems as well, though I did not test that.
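For GPU timings like the ones in the table, the device must be synchronized around the timed region, or the numbers only measure kernel launches. A benchmark along these lines could produce throughput in inversions per microsecond (the notebook's exact methodology may differ; the function name and iteration count here are assumptions):

```python
import time

import torch

def throughput(fn, m: torch.Tensor, iters: int = 100) -> float:
    """Return inversions per microsecond for fn applied to batch m."""
    fn(m)  # warm-up: trigger any lazy init / compilation
    if m.is_cuda:
        torch.cuda.synchronize()  # make sure pending kernels finish
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(m)
    if m.is_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels too
    elapsed = time.perf_counter() - t0
    return m.shape[0] * iters / (elapsed * 1e6)
```

Without the `torch.cuda.synchronize()` calls, CUDA's asynchronous execution would make the GPU variants look artificially fast.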
cc @VitalyFedyunin @ngimel @heitorschueroff @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano