Unnecessary cuda synchronizations that we should remove in PyTorch #108968
Description
🚀 The feature, motivation and pitch
There are a number of unnecessary cuda synchronizations in PyTorch ops, and I think we should endeavor to remove them whenever possible.
To check syncs, you can use torch.cuda.set_sync_debug_mode("warn")
I'm creating this issue to track ones that I've seen/found.
- torch.multinomial with num_samples=1. For this I think we should simply remove the error check causing the sync, and ideally turn it into a cuda async error. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Distributions.cpp#L615
A = torch.rand(10, device='cuda')  # must be a cuda tensor for the sync to occur
torch.multinomial(A, num_samples=1)
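One reason the num_samples=1 path doesn't need the host-side error check is that a single draw can be taken with the Gumbel-max trick, which samples proportionally to unnormalized non-negative weights and never needs to verify that the probabilities sum to 1. A minimal pure-Python sketch of that idea (an illustration, not PyTorch's actual CUDA kernel):

```python
import math
import random

def multinomial_one(weights):
    # Gumbel-max trick: argmax(log w_i + g_i), with g_i standard Gumbel noise,
    # returns index i with probability w_i / sum(w). No normalization pass and
    # no "weights sum to zero" reduction is needed, which is why the
    # num_samples=1 case could skip the device-to-host error check entirely.
    best_val, best_idx = -math.inf, -1
    for i, w in enumerate(weights):
        if w > 0:
            g = -math.log(-math.log(random.random()))  # Gumbel(0, 1) sample
            val = math.log(w) + g
            if val > best_val:
                best_val, best_idx = val, i
    return best_idx
```

On the GPU the argmax would be a single reduction kernel, so invalid inputs could surface as an async error instead of a blocking check.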
- repeat_interleave with a tensor number of repeats encourages synchronization. We cannot use repeats with a non-cuda tensor, and that forces a synchronization. For this I think we should add a list-of-ints overload or allow passing a CPU tensor for repeats.
A = torch.randn(3, device='cuda')
num_repeats = torch.tensor([2, 3, 5])
out = torch.repeat_interleave(A, num_repeats.cuda(), dim=0)
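The underlying problem is that the output length is sum(repeats): when the repeats tensor lives on the device, that sum has to be copied back to the host before the output can be allocated. A pure-Python sketch of the op's semantics (a hypothetical helper, not the PyTorch implementation), which makes the size dependence explicit:

```python
def repeat_interleave(values, repeats):
    # Output length is sum(repeats). On CUDA, if `repeats` is a device
    # tensor, computing that length on the host to size the allocation is
    # exactly the device-to-host synchronization this issue describes. A
    # CPU tensor or list-of-ints overload would make the size known without
    # touching the device.
    out = []
    for v, r in zip(values, repeats):
        out.extend([v] * r)
    return out
```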
- Indexing with a scalar tensor performs a synchronization. See Turn indexing with a scalar tensor into a view and avoid a D2H synchronization #105641 for more details.
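Roughly, returning a view for x[idx] requires idx as a host integer to compute the storage offset, so a GPU scalar index has to be copied back; producing a copy via a gather kernel instead would keep the index on the device. A hypothetical pure-Python sketch of the offset computation (illustrative names, not PyTorch internals):

```python
def index_select_scalar(data, row_len, idx):
    # A view of row `idx` is just (offset = idx * row_len, length = row_len).
    # Computing `offset` needs `idx` as a host integer; if `idx` lives on the
    # GPU, reading it back is a D2H sync. A gather kernel producing a copy
    # could consume `idx` on the device and avoid the sync.
    start = idx * row_len
    return data[start:start + row_len]
```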
- torch.normal also incurs a sync on std: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/DistributionTemplates.h#L222
- nanmedian incurs a sync: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Sorting.cpp#L149
- prod_backward: torch.prod cannot be used with cudagraphs #128396
Alternatives
No response
Additional context
No response
cc @ptrblck