
Improve the performance of btriunpack #1791

@bamos

The part of btriunpack that extracts the pivots has been causing some unexpected performance bottlenecks in qpth. Here's a newer version I've tried that uses gather/scatter operations across a batched index vector instead of row interchanges on a batched matrix. I think it's a step towards a better method, but in its current form it is just as slow. What I want is what the LAPACK LASWP function provides, but applied over a batch, so maybe we could use some knowledge from existing implementations, like this one in OpenBLAS.
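For reference, the sequence of row interchanges that LASWP performs can be sketched in plain Python for a single (non-batched) pivot vector. This is only an illustration of the semantics, assuming 1-based pivots as LAPACK returns them; `apply_pivots` is a hypothetical helper name, not part of PyTorch or LAPACK.

```python
def apply_pivots(pivots):
    """Turn 1-based LAPACK-style pivots into a permutation of row indices."""
    perm = list(range(len(pivots)))
    for i, p in enumerate(pivots):
        k = p - 1  # convert the 1-based pivot to a 0-based row index
        perm[i], perm[k] = perm[k], perm[i]  # interchange rows i and k
    return perm


print(apply_pivots([3, 3, 3]))  # [2, 0, 1]
```

The interchanges have to be applied in order, since step i sees the result of steps 0..i-1. That is why the loop over the matrix size is hard to remove, while the work across the batch is independent.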

Slightly improved, but still slow, pivot matrix extraction using gather/scatter

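# Assumes LU_data is (nBatch, sz, sz) and LU_pivots is (nBatch, sz) with 1-based pivots.
# Start from the identity permutation [0, ..., sz-1] for every batch element.
Pidx = type(LU_data)(range(sz)).repeat(nBatch, 1).long()

# Apply the row interchanges in order, for the whole batch at once.
for i in range(sz):
    k = (LU_pivots[:, i] - 1).long().unsqueeze(1)  # 0-based pivot row, shape (nBatch, 1)
    t = Pidx[:, i].clone()
    # gather returns shape (nBatch, 1); squeeze it to match the column being assigned.
    Pidx[:, i] = torch.gather(Pidx, 1, k).squeeze(1)
    Pidx.scatter_(1, k, t.unsqueeze(1))

# Expand each index vector into a permutation matrix: P[i][Pidx[i][j]][j] = 1.
P = type(LU_data)(nBatch, sz, sz).zero_()
for i in range(nBatch):
    P[i].scatter_(0, Pidx[i].unsqueeze(0), 1.0)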
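One possible variation, shown here only as a sketch: the final loop over the batch can be replaced by a single `scatter_` on the 3-D tensor. This assumes a recent PyTorch release and a `(nBatch, sz)` integer tensor of 1-based pivots; `pivots_to_perm_matrix` is a hypothetical name. The sequential loop over `sz` is still there, and no timing is claimed.

```python
import torch


def pivots_to_perm_matrix(LU_pivots):
    """LU_pivots: (nBatch, sz) integer tensor of 1-based pivots.

    Returns a (nBatch, sz, sz) float tensor of permutation matrices.
    """
    nBatch, sz = LU_pivots.shape
    piv = LU_pivots.long() - 1  # 0-based pivot rows
    # Identity permutation for every batch element.
    Pidx = torch.arange(sz, device=LU_pivots.device).repeat(nBatch, 1)

    # The interchanges must still be applied in order over i.
    for i in range(sz):
        k = piv[:, i].unsqueeze(1)  # shape (nBatch, 1)
        t = Pidx[:, i].clone()
        Pidx[:, i] = Pidx.gather(1, k).squeeze(1)
        Pidx.scatter_(1, k, t.unsqueeze(1))

    # One batched scatter instead of a Python loop over the batch:
    # P[b, Pidx[b, j], j] = 1 for every b and j.
    P = torch.zeros(nBatch, sz, sz, device=LU_pivots.device)
    P.scatter_(1, Pidx.unsqueeze(1), 1.0)
    return P
```

For example, `pivots_to_perm_matrix(torch.tensor([[3, 3, 3], [1, 2, 3]]))` gives a batch of two 3x3 matrices, the second of which is the identity because its pivots request no interchanges.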

Labels: feature (a request for a proper, new feature), todo (not as important as medium or high priority tasks, but we will work on these)