Add autograd for to_sparse. #20458
Conversation
Gets TestJitGeneratedAutograd.test_to_sparse to pass, but needs feedback from the JIT team on test_sparse_tensors_error.
As discussed with @eellison, we need feedback from the JIT team on support for sparse tensors.
@pytorchbot rebase this please
I don't think we want to add the additional complexity of sparse tensors to JIT yet. We have an existing pass …
@pytorchbot rebase this please
zou3519 left a comment:
logic looks fine, some comments though
tools/autograd/derivatives.yaml
  - name: sparse_mask(Tensor self, Tensor mask)
    self: not_implemented("sparse_mask")
    mask: not_implemented("sparse_mask")
    self: grad.to_dense().sparse_mask(mask).to_dense()
Why is this necessary to implement to_sparse backward?
Also, this (sparse_mask backward) needs a test given that you've implemented it.
In ./aten/src/ATen/native/TensorConversions.cpp you can see that to_dense_backward calls sparse_mask, which ends up causing this to be required.
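For context, here is a rough Python paraphrase of that C++ backward (simplified sketch, not the actual implementation, which lives in TensorConversions.cpp):

    # Rough paraphrase of to_dense_backward (simplified): the gradient of
    # to_dense keeps only the entries present in the original sparse input,
    # which is exactly what sparse_mask computes.
    def to_dense_backward(grad, sparse_input):
        return grad.sparse_mask(sparse_input.coalesce())

Because gradgradcheck differentiates through this backward, sparse_mask itself needs a derivative, which is the formula under review above.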
why do you need to densify the grad? I guess we don't have a kernel doing what you want, but it seems feasible, though maybe not worth the effort.
sparse_mask requires a dense tensor and a sparse mask. If I don't call to_dense() on the grad, IIRC, sparse_mask rejects it. If I don't call to_dense() on the result, IIRC I got an error that the gradient layout has to match. I guess we could write an explicit kernel that took a sparse gradient and a sparse mask and returned a dense gradient, though. I can add one, or at least file a ticket, if you think it's worthwhile?
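For illustration, a minimal standalone sketch of that round trip (made-up values, not code from the PR):

    import torch

    # Sketch of the round trip in the proposed sparse_mask backward formula:
    # the incoming grad is sparse, sparse_mask wants a dense self and a sparse
    # mask, and the gradient for the dense `self` must itself be dense.
    mask = torch.tensor([1.0, 0.0, 1.0]).to_sparse()
    grad = torch.ones(3).to_sparse()          # sparse upstream gradient
    grad_self = grad.to_dense().sparse_mask(mask).to_dense()
    print(grad_self)                          # tensor([1., 0., 1.])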
test/test_autograd.py (Outdated)
    def test_sparse_gather_both_scalar(self):
        self._test_sparse_gather((), (), 0)

    def test_to_sparse_autograd(self):
Why is this function necessary? There's already a to_sparse test in common_method_invocations that autograd should call
Probably not; I added this first, before realizing I should add it in common_method_invocations.
I wasn't able to add a sparse_mask test in common_method_invocations because it's not set up to handle tests with sparse inputs, so I converted this test to test sparse_mask.
Does sparse_mask also get tested indirectly via the common_method_invocations entry for to_sparse?
gradcheck and gradgradcheck are run via the common_method_invocations entries, so yes, sparse_mask is indirectly tested (in the gradgradcheck step), but it's always good to have an explicit test.
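For reference, entries in that file are just tuples in a method_tests list; a hypothetical sketch of what a to_sparse entry might look like (file name and exact tuple layout assumed, not taken from this PR):

    # Hypothetical method_tests-style entry: method name, input size, extra args.
    # S stands in for the small size constant used throughout that test file.
    S = 5
    entry = ('to_sparse', (S, S), ())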
test/test_autograd.py (Outdated)
    def test_sparse_mask_autograd(self):
        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
            tensor = torch.randn(3, requires_grad=True, device=device)
            mask = torch.ones(3, device=device).to_sparse()
You should test a mask that is not just all ones, something like torch.tensor([1, 0, 1])
Ah yes, that was a bit lazy of me. Fixed.
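The updated test presumably looks roughly like this (a sketch reconstructed from the review suggestion and the final diff quoted further down; exact values assumed):

    def test_sparse_mask_autograd(self):
        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
            tensor = torch.randn(3, requires_grad=True, device=device)
            # per the review: a mask that is not all ones
            mask = torch.tensor([1.0, 0.0, 1.0], device=device)
            mask = mask.to_sparse()
            converted = tensor.sparse_mask(mask).to_dense()
            converted.sum().backward()
            self.assertEqual(tensor.grad, mask.to_dense())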
zou3519 left a comment:
lgtm
facebook-github-bot left a comment:
@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
            mask = mask.to_sparse()
            converted = tensor.sparse_mask(mask).to_dense()
            converted.sum().backward()
            self.assertEqual(tensor.grad, mask.to_dense())
does gradcheck / gradgradcheck not work here?
I think @nairbv said gradcheck / gradgradcheck can't handle sparse tensors (as autograd leaves).
But it could be good to run gradcheck on tensor.sparse_mask(mask.to_sparse()) (where tensor and mask are dense tensors), if that works.
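A minimal sketch of that suggestion (an illustration, not code from the PR; dense inputs only, so no sparse tensor appears as an autograd leaf, and double precision is used for gradcheck):

    import torch
    from torch.autograd import gradcheck

    tensor = torch.randn(3, dtype=torch.double, requires_grad=True)
    mask = torch.tensor([1.0, 0.0, 1.0], dtype=torch.double)

    # The mask is converted to sparse inside the checked function, so only a
    # dense tensor is passed to gradcheck as an autograd input.
    gradcheck(lambda t: t.sparse_mask(mask.to_sparse()).to_dense(), (tensor,))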
ya, I think we've run into this before and it can be worked around.
sparse_mask didn't allow for a dense mask; see #18111.