Conversation

Collaborator

@nairbv nairbv commented May 13, 2019

@pytorchbot pytorchbot added the module: autograd, module: internals, and module: tests labels May 13, 2019
Gets TestJitGeneratedAutograd.test_to_sparse to pass, but need feedback from the JIT team on test_sparse_tensors_error.
@pytorchbot pytorchbot added the oncall: jit and module: pybind labels May 14, 2019
@nairbv nairbv requested a review from eellison May 14, 2019 21:15
Collaborator Author

nairbv commented May 14, 2019

As discussed with @eellison, need feedback from jit team on support for sparse tensors.

@nairbv nairbv changed the title from "initial not-working attempt to add autograd for to_sparse" to "Add autograd for to_sparse." May 14, 2019
Collaborator Author

nairbv commented May 15, 2019

@pytorchbot rebase this please

@eellison
Contributor

I don't think we want to add the additional complexity of Sparse tensors to JIT yet.

We have an existing pass, inplace_check.cpp, that checks for one particular unsupported op before running the JIT. Maybe you could rename the pass and check for aten::to_sparse in that pass as well?

How does that sound, @suo @zdevito?

@eellison eellison requested review from apaszke and zdevito May 16, 2019 18:07
Collaborator Author

nairbv commented May 21, 2019

@pytorchbot rebase this please

@nairbv nairbv requested a review from gchanan June 5, 2019 14:31
Contributor

@zou3519 zou3519 left a comment

logic looks fine, some comments though

- name: sparse_mask(Tensor self, Tensor mask)
self: not_implemented("sparse_mask")
mask: not_implemented("sparse_mask")
self: grad.to_dense().sparse_mask(mask).to_dense()
Contributor

Why is this necessary to implement to_sparse backward?
Also, this (sparse_mask backward) needs a test given that you've implemented it.

Collaborator Author

In ./aten/src/ATen/native/TensorConversions.cpp you can see that to_dense_backward calls sparse_mask, which ends up causing this to be required.
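A minimal sketch of why that dependency appears (illustration only, not the ATen source, and it assumes a build that includes this PR): to_sparse's backward densifies the gradient, and to_dense's backward masks the incoming gradient with the input's sparsity pattern via sparse_mask, so differentiating through the round trip reaches sparse_mask.

import torch

x = torch.randn(3, requires_grad=True)
s = x.to_sparse()   # backward of this step densifies the incoming grad
d = s.to_dense()    # backward of this step calls sparse_mask internally
d.sum().backward()
print(x.grad)       # identity round trip, so the gradient is all ones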

Contributor

@gchanan gchanan Jun 6, 2019

Why do you need to densify the grad? I guess we don't have a kernel doing what you want, but it seems feasible, though maybe not worth the effort.

Collaborator Author

sparse_mask requires a dense tensor and a sparse mask. If I don't call to_dense() on the grad, IIRC, sparse_mask rejects it. If I don't call to_dense() on the result, IIRC I got an error that the gradient layout has to match. I guess we could write an explicit kernel that takes a sparse gradient and returns a dense result, though. I can add one, or at least file a ticket, if you think it's worthwhile?
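For concreteness, a hand-rolled 1-D walk-through of the derivative formula from derivatives.yaml (self: grad.to_dense().sparse_mask(mask).to_dense()); the tensors and values here are just an example:

import torch

mask = torch.tensor([1., 0., 1.]).to_sparse()
grad_out = torch.ones(3).sparse_mask(mask)   # upstream grad, sparse like sparse_mask's output
# densify so sparse_mask accepts it, re-mask, then densify again so the
# layout matches the dense input that sparse_mask was applied to
grad_self = grad_out.to_dense().sparse_mask(mask).to_dense()
print(grad_self)   # tensor([1., 0., 1.])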

def test_sparse_gather_both_scalar(self):
    self._test_sparse_gather((), (), 0)

def test_to_sparse_autograd(self):
Contributor

Why is this function necessary? There's already a to_sparse test in common_method_invocations that autograd should call

Collaborator Author

Probably not; I added this first, before realizing I should add it in common_method_invocations.

Collaborator Author

I wasn't able to add a sparse_mask test in common_method_invocations because it's not set up to handle tests with sparse inputs, so I converted this test to test sparse_mask.

Does sparse_mask also get tested indirectly via the common method invocations entry for to_sparse?

Contributor

gradcheck and gradgradcheck are run via the common method invocations, so yes, sparse_mask is indirectly tested (in the gradgradcheck step), but it's always good to have an explicit test.
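Roughly what that indirect coverage looks like (a sketch, not the actual common_method_invocations harness; the to_dense() round trip is added here just so the checked function's inputs and outputs stay dense):

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
fn = lambda t: t.to_sparse().to_dense()
assert gradcheck(fn, (x,))       # first derivative: exercises to_sparse's backward
assert gradgradcheck(fn, (x,))   # second derivative: goes through sparse_mask's backward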

def test_sparse_mask_autograd(self):
    for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
        tensor = torch.randn(3, requires_grad=True, device=device)
        mask = torch.ones(3, device=device).to_sparse()
Contributor

You should test a mask that is not just all ones, something like torch.tensor([1, 0, 1])

Collaborator Author

Ah yes, that was a bit lazy of me. Fixed.

Contributor

@zou3519 zou3519 left a comment

lgtm

Contributor

@facebook-github-bot facebook-github-bot left a comment

@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

mask = mask.to_sparse()
converted = tensor.sparse_mask(mask).to_dense()
converted.sum().backward()
self.assertEqual(tensor.grad, mask.to_dense())
Contributor

Does gradcheck / gradgradcheck not work here?

Contributor

I think @nairbv said gradcheck / gradgradcheck can't handle sparse tensors (as autograd leaves).

But it could be good to run gradcheck on tensor.sparse_mask(mask.to_sparse()) (where tensor and mask are dense tensors), if that works.
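A sketch of that suggestion (the mask values, the double dtype, and the trailing to_dense() are additions here; only tensor is a differentiable leaf, and the mask is converted to sparse inside the checked function):

import torch
from torch.autograd import gradcheck

tensor = torch.randn(3, dtype=torch.double, requires_grad=True)
mask = torch.tensor([1., 0., 1.], dtype=torch.double)
fn = lambda t: t.sparse_mask(mask.to_sparse()).to_dense()
assert gradcheck(fn, (tensor,))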

Contributor

ya, I think we've run into this before and it can be worked around.

Collaborator Author

sparse_mask didn't allow for a dense mask

@facebook-github-bot
Contributor

@nairbv merged this pull request in 8a9ea55.
