🐛 Bug
In the index_put_ operation, the indices' backend now has to match the source tensor's backend. This used to be unnecessary, so that e.g. a CUDA tensor could be indexed by a CPU tensor. Blame points to #17991: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Indexing.cpp#L204
To Reproduce
import torch
a=torch.arange(30, dtype=torch.float).view(5,6).cuda()
ind0 = torch.arange(0,a.size(0), step=2)
gO = torch.randn(a[ind0].size()).cuda()
a.index_put_((ind0,), gO, accumulate=True)
torch.cuda.synchronize()
This used to work, but now it does not. Note that the forward indexing operation with ind0 on the CPU still works (a[ind0]); only index_put_ breaks.
Expected behavior
Either clarify in the docs and in the code that the indices tensor has to be on the same backend as the source tensor, or fix the RuntimeError. If we want to disallow indexing with a different backend, it would make sense to disallow it for the forward indexing operation too; right now forward indexing succeeds, but backward will throw a runtime error.
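Until this is resolved, the error from the repro above can be avoided by explicitly moving the index tensor onto the source tensor's device before calling index_put_. A minimal sketch of that workaround follows; the device selection line is an assumption added so the snippet also runs on CPU-only machines and is not part of the original repro:

```python
import torch

# Assumption: fall back to CPU when CUDA is unavailable so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.arange(30, dtype=torch.float).view(5, 6).to(device)
before = a.clone()
ind0 = torch.arange(0, a.size(0), step=2)      # index tensor created on the CPU
gO = torch.randn(a[ind0].size()).to(device)

# Workaround: move the indices to the source tensor's device so the backends match.
a.index_put_((ind0.to(device),), gO, accumulate=True)
```

With accumulate=True, rows 0, 2, and 4 of a end up equal to their original values plus the corresponding rows of gO, while the remaining rows are unchanged.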
Environment
PyTorch 1.1 binaries and recent source builds
cc @colesbury