Closed
Labels
feature: A request for a proper, new feature.
todo: Not as important as medium or high priority tasks, but we will work on these.
Description
The part of btriunpack that extracts pivots has been causing some unexpected performance bottlenecks in qpth. Here's a newer version I've tried that uses gather/scatter operations across a batched vector instead of row interchanges on a batched matrix. I think it's a step towards a better method, but the current form is just as slow. I want to do what the LAPACK LASWP function provides, but over a batch, so maybe we could use some knowledge from existing implementations, like this one in OpenBLAS.
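For reference, here is a minimal pure-Python sketch of the forward row-interchange semantics I mean by "LASWP over a batch". The name `laswp_batch` and the list-of-lists layout are assumptions made only for illustration; this is not the LAPACK or OpenBLAS implementation, just the behaviour a faster version would need to reproduce.

```python
def laswp_batch(matrices, pivots):
    """Apply LAPACK-style row interchanges to each matrix in a batch.

    matrices: list of matrices, each given as a list of rows.
    pivots:   list of 1-based pivot vectors, one per matrix
              (row i is swapped with row pivots[i] - 1, for i = 0, 1, ...).
    Hypothetical helper for illustration only.
    """
    out = []
    for rows, piv in zip(matrices, pivots):
        rows = list(rows)  # shallow copy so the input is left untouched
        for i, p in enumerate(piv):
            k = p - 1  # convert the 1-based pivot to a 0-based row index
            rows[i], rows[k] = rows[k], rows[i]
        out.append(rows)
    return out


A = [[1, 2], [3, 4], [5, 6]]
swapped = laswp_batch([A], [[3, 3, 3]])
```

The swaps have to be applied in order, because each interchange acts on the result of the previous one. That sequential dependency across `i` is what makes the loop hard to batch away.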
Slightly improved, but still slow, pivot matrix extraction using gather/scatter:
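# Start from the identity permutation for every batch element.
Pidx = type(LU_data)(range(sz)).repeat(nBatch, 1).long()
# Apply the 1-based LAPACK pivots as swaps of index entries, one column at a time.
for i in range(sz):
    k = LU_pivots[:, i] - 1
    t = Pidx[:, i].clone()
    Pidx[:, i] = torch.gather(Pidx, 1, k.unsqueeze(1).long())
    Pidx.scatter_(1, k.unsqueeze(1).long(), t.unsqueeze(1))
# Expand each permutation index vector into a permutation matrix.
P = type(LU_data)(nBatch, sz, sz).zero_()
for i in range(nBatch):
    P[i].scatter_(0, Pidx[i].unsqueeze(0), 1.0)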
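To sanity-check any faster variant on small inputs, the snippet above can be mirrored in plain Python for a single batch element. This is only a sketch: `pivots_to_perm` and `perm_to_matrix` are hypothetical names, and the matrix convention (a one at row `perm[j]`, column `j`) is my reading of the `scatter_` call along dimension 0.

```python
def pivots_to_perm(pivots):
    """Turn 1-based LAPACK pivots into a 0-based permutation index vector."""
    perm = list(range(len(pivots)))
    for i, p in enumerate(pivots):
        k = p - 1
        perm[i], perm[k] = perm[k], perm[i]  # same swap as the gather/scatter pair
    return perm


def perm_to_matrix(perm):
    """Build the permutation matrix with a one at (perm[j], j) for each column j."""
    n = len(perm)
    P = [[0.0] * n for _ in range(n)]
    for j, r in enumerate(perm):
        P[r][j] = 1.0
    return P


perm = pivots_to_perm([3, 3, 3])
P = perm_to_matrix(perm)
```

Comparing the output of a batched implementation against this reference, element by element, should catch off-by-one mistakes in the 1-based pivot handling.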