Closed
Labels: module: cuda, triaged
Description
Here is code comparing the CUDA and non-CUDA behavior:
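```python
import torch

# CPU: m2 is a non-contiguous view of column 4 of m1 (stride 5)
m1 = torch.randn(2,5)
m2 = m1[:, 4]
# 1-D sparse tensor with value 33 at index 0 and 44 at index 1
x = torch.sparse.DoubleTensor(torch.LongTensor([[0,1]]), torch.DoubleTensor([33, 44]))
m2[0] = 0
m2[1] = 0
m2.add_(x)  # in-place: dense m2 += sparse x
print(x)
print(m1)
print(m2)
print("----")
# Same steps on the GPU
m1 = torch.randn(2,5).cuda()
m2 = m1[:, 4]
x = torch.sparse.DoubleTensor(torch.LongTensor([[0,1]]), torch.DoubleTensor([33, 44])).cuda()
m2[0] = 0
m2[1] = 0
m2.add_(x)
print(x)
print(m1)
print(m2)
```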
Output:
```
DoubleTensor of size 2 with indices:
 0  1
[torch.LongTensor of size 1x2]
and values:
 33
 44
[torch.DoubleTensor of size 2]

-0.5214 -1.4914 -0.2381  1.0306 33.0000
 1.5162 -1.5116 -0.5050 -0.2216 44.0000
[torch.DoubleTensor of size 2x5]

 33
 44
[torch.DoubleTensor of size 2]

----
DoubleTensor of size 2 with indices:
 0  1
[torch.cuda.LongTensor of size 1x2 (GPU 0)]
and values:
 33
 44
[torch.cuda.DoubleTensor of size 2 (GPU 0)]

-1.5411  0.8085  1.0213 -0.1240  0.0000
-1.2078 -0.5452  0.0656 -0.7886  0.0000
[torch.cuda.DoubleTensor of size 2x5 (GPU 0)]

 0
 0
[torch.cuda.DoubleTensor of size 2 (GPU 0)]
```
In the CUDA case, no update actually occurred: column 4 of m1, and therefore m2, is still 0 after add_.
Looking at the CUDA code, it's pretty clear what the problem is:
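```c
THCTensor *r = r_;
if (r != dense) {
  // Out-of-place: r_ is resized to match dense and filled with a copy of it.
  THCTensor_(retain)(state, r);
  THCTensor_(resizeAs)(state, r, dense);
  THCTensor_(copy)(state, r, dense);
} else {
  // In-place (r_ == dense): if r_ is not contiguous, newContiguous returns
  // a new contiguous copy, so later writes to r never reach r_.
  r = THCTensor_(newContiguous)(state, r_);
}
```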
In the in-place case (r_ == dense, as with m2.add_(x)), the call to newContiguous creates a clone of r_ if r_ is not contiguous. Updates to r are then applied to that clone and are never reflected in r_.
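To make the aliasing problem concrete, here is a toy Python model of a strided view over shared storage. It is not PyTorch or THC code: `View`, `new_contiguous`, `spcadd_inplace_buggy` and `spcadd_inplace_fixed` are hypothetical names. It mimics the in-place branch above, where `new_contiguous` returns a fresh copy for a non-contiguous view such as column 4 of a 2x5 row-major matrix, so updates made through the copy never reach the original. The "fixed" variant shows one possible remedy, copying the contiguous result back into `r_`; the actual fix in PyTorch may have been done differently.

```python
# Toy model of a strided 1-D view (hypothetical; not the THC API).
class View:
    def __init__(self, storage, offset, size, stride):
        self.storage = storage  # shared flat list, like a tensor's storage
        self.offset = offset
        self.size = size
        self.stride = stride

    def is_contiguous(self):
        return self.size <= 1 or self.stride == 1

    def get(self, i):
        return self.storage[self.offset + i * self.stride]

    def set(self, i, v):
        self.storage[self.offset + i * self.stride] = v


def new_contiguous(t):
    """Mimics newContiguous: return t itself if contiguous, else a fresh copy."""
    if t.is_contiguous():
        return t
    data = [t.get(i) for i in range(t.size)]
    return View(data, 0, t.size, 1)


def spcadd_inplace_buggy(r_, indices, values):
    # Work on a possibly-copied tensor and never write the result back,
    # so for a non-contiguous r_ the updates land only in the copy.
    r = new_contiguous(r_)
    for idx, v in zip(indices, values):
        r.set(idx, r.get(idx) + v)


def spcadd_inplace_fixed(r_, indices, values):
    r = new_contiguous(r_)
    for idx, v in zip(indices, values):
        r.set(idx, r.get(idx) + v)
    if r is not r_:
        # Copy the contiguous result back into the original strided view.
        for i in range(r_.size):
            r_.set(i, r.get(i))


# m1 is a 2x5 row-major matrix; m2 = m1[:, 4] has offset 4 and stride 5.
m1 = [0.0] * 10
m2 = View(m1, offset=4, size=2, stride=5)
spcadd_inplace_buggy(m2, [0, 1], [33.0, 44.0])
print(m2.get(0), m2.get(1))  # 0.0 0.0 (update lost)
spcadd_inplace_fixed(m2, [0, 1], [33.0, 44.0])
print(m2.get(0), m2.get(1))  # 33.0 44.0
```

The buggy variant reproduces the CUDA output above (column stays 0), while the contiguous case is unaffected because `new_contiguous` returns the original view.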
CC @martinraison who originally added spcadd, and @adamlerer who sped it up.