Direct inversion and linear systems solutions for small matrices #63992

@o-alexandre-felipe

Description
🚀 Feature

Performance improvements to linalg on GPUs

Motivation

After noticing that one of my functions was running unusually slowly, and timing it line by line, I found that a batched inversion of 2×2 matrices was hundreds of times slower than an 8k-point FFT on GPU. Experiments show that a 2×2 inversion formula implemented in Python can be hundreds of times faster on GPU than the linalg implementations of inv and solve in PyTorch 1.9.0.

Pitch

The linalg module is an active development area in PyTorch, so it may offer many opportunities for improvement, and I seem to have identified an important one.

Alternatives

Implementing this feature in a Python wrapper, as I did, is probably not a good idea, since it would likely hurt performance for small batches.

Additional context

I prepared a notebook that can be viewed here, where I describe the implementation with the analytic inversion formula and run the experiments, so that you can review my methodology.
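For reference, the analytic formula in question is the closed-form adjugate inverse of a 2×2 matrix, inv(A) = adj(A) / det(A). Below is a minimal batched sketch of that formula; it uses NumPy for illustration (the notebook works with PyTorch tensors, but the arithmetic is identical, and the function name `inv2x2` is mine, not from the notebook):

```python
import numpy as np

def inv2x2(m):
    # m: (..., 2, 2) batch of matrices; returns the batched inverse
    # via the closed form inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (ad - bc).
    a = m[..., 0, 0]
    b = m[..., 0, 1]
    c = m[..., 1, 0]
    d = m[..., 1, 1]
    det = a * d - b * c
    out = np.empty_like(m)
    out[..., 0, 0] = d
    out[..., 0, 1] = -b
    out[..., 1, 0] = -c
    out[..., 1, 1] = a
    return out / det[..., None, None]
```

Because this is just a handful of elementwise kernels, it avoids the per-matrix LAPACK/MAGMA dispatch overhead that dominates for tiny matrices, which is presumably where the GPU speedup comes from.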

Just to give an idea, here I show the performance of a batched solve, in number of 2×2 matrix inversions per microsecond.

| computation | float32 | float64 | complex64 | complex128 |
|---|---|---|---|---|
| custom@cuda | 3533.943528 | 4137.873960 | 3723.701359 | 3645.776368 |
| linalg@cuda | 22.632545 | 22.065084 | 22.591386 | 16.020514 |
| custom@cpu | 117.163244 | 47.921409 | 43.829660 | 12.162328 |
| linalg@cpu | 9.671292 | 7.813403 | 7.009462 | 6.682558 |

The proposed implementation is not very appealing on CPU, but it gives massive gains on the GPU backend. The only scenario where it is considerably worse than the linalg implementation is CPU with backpropagation for complex128; I don't believe this is a very important case, since GPUs are generally available when training.

The proposed method wins by a significant margin, which suggests there may be gains for larger systems as well, but I did not test that.
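The batched solve benchmarked above can be done the same way without forming the inverse at all, using Cramer's rule for the 2×2 case. A minimal NumPy sketch (again, `solve2x2` is my name for illustration, not an identifier from the notebook):

```python
import numpy as np

def solve2x2(m, rhs):
    # m: (..., 2, 2), rhs: (..., 2); solves m @ x = rhs per batch element
    # via Cramer's rule: x_i = det(M with column i replaced by rhs) / det(M).
    a = m[..., 0, 0]
    b = m[..., 0, 1]
    c = m[..., 1, 0]
    d = m[..., 1, 1]
    det = a * d - b * c
    x0 = (rhs[..., 0] * d - b * rhs[..., 1]) / det
    x1 = (a * rhs[..., 1] - rhs[..., 0] * c) / det
    return np.stack([x0, x1], axis=-1)
```

Like the inverse formula, this compiles down to a few elementwise operations, so it batches trivially on GPU; note that neither formula does any pivoting, so numerical behavior for ill-conditioned matrices will differ from the LAPACK-backed routines.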

cc @VitalyFedyunin @ngimel @heitorschueroff @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

Labels

- module: cuda — Related to torch.cuda, and CUDA support in general
- module: linear algebra — Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
- module: performance — Issues related to performance, either of kernel code or framework glue
- triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
