🚀 Feature
When `x` is a tensor of shape `(N, D)` and `idx` is a tensor of indices of shape `(K,)`, the backward pass of `x[idx]` is much slower than the equivalent operation implemented using `gather`. Here is a benchmarking script:
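The issue's original benchmarking script is not reproduced here; the following is a minimal sketch of the comparison it describes (CPU timing via `time.perf_counter`; on a GPU you would add `torch.cuda.synchronize()` around the timed region). The shapes `N`, `D`, `K` and the `bench` helper are illustrative choices, not the author's exact setup.

```python
import time
import torch

def bench(fn, iters=20):
    # crude wall-clock timing with one warmup call
    fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

N, D, K = 10000, 256, 4096
x = torch.randn(N, D, requires_grad=True)
idx = torch.randint(0, N, (K,))

def index_fwd_bwd():
    y = x[idx]                      # integer (advanced) indexing
    y.sum().backward()
    x.grad = None

def gather_fwd_bwd():
    # gather selects along dim 0 and needs a (K, D) index tensor
    y = x.gather(0, idx.unsqueeze(1).expand(K, D))
    y.sum().backward()
    x.grad = None

t_index = bench(index_fwd_bwd)
t_gather = bench(gather_fwd_bwd)
print(f"index:  {t_index * 1e3:.3f} ms/iter")
print(f"gather: {t_gather * 1e3:.3f} ms/iter")
print(f"gather speedup: {t_index / t_gather:.2f}x")
```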
On a P100 GPU with PyTorch 1.0 stable, across a variety of problem shapes I get the following results:
```
Forward gather speedup:
  Min:    0.7549055905220289
  Max:    5.590410529614541
  Mean:   0.9328673787035276
  Median: 0.880012936610608

Backward gather speedup:
  Min:    1.6313537996980372
  Max:    23.95120218579235
  Mean:   3.340551245050125
  Median: 1.8802246977054176
```
In short: on the forward pass, indexing is faster for some shapes and `gather` is faster for others, but on the backward pass `gather` is always faster than integer indexing.
This is surprising, and suggests that although the two operations perform the same computation their implementations have very different performance characteristics. Integer indexing is much more intuitive than gather, so I suspect that many users are unknowingly leaving a lot of performance on the table by choosing integer indexing over gather. In one of my own applications, replacing integer indexing with gather resulted in a more than 2x speedup on my overall training iteration times!
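For reference, the substitution described above can be sketched as follows: row selection via integer indexing and via `gather` compute the same result, so one can be swapped for the other. The shapes here are illustrative.

```python
import torch

N, D, K = 8, 4, 5
x = torch.randn(N, D)
idx = torch.randint(0, N, (K,))

# integer (advanced) indexing: selects K rows of x
a = x[idx]

# equivalent gather: expand idx to shape (K, D), then gather along dim 0
b = x.gather(0, idx.unsqueeze(1).expand(K, D))

assert torch.equal(a, b)
print(a.shape)  # torch.Size([5, 4])
```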
Would it be possible to somehow unify the implementation of the two operations, or otherwise ensure that integer indexing always performs at least as well as gather?