Add sparse gradient option to gather operation
#17182
Conversation
Force-pushed from e861a36 to 31d33a8
CI is failing (maybe you need a rebase?)
gchanan left a comment:
@soumith what's our current story with torch.sparse? If the input tensors can all be dense but we want the gradient to be sparse, we don't put it in torch.sparse?
I'm not convinced getting rid of this is safe, but unfortunately the error message doesn't mention why it's not safe...
Let me dig into this a little bit more.
The test failures I saw (test_alexnet) can be fixed by rebasing or merging in master.
Force-pushed from 31d33a8 to fa30f47
This PR follows embedding, where the inputs are also all dense and gradient sparsity is controlled by a keyword argument, and embedding is not in torch.sparse.
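For reference, a minimal sketch of the embedding pattern being referred to (dense layer and inputs, sparse gradient opted into via a keyword argument); the shapes are arbitrary illustration:
```
import torch

# nn.Embedding already follows the dense-inputs / sparse-gradient pattern:
# the layer and its inputs are dense, and sparse=True only changes the gradient layout.
emb = torch.nn.Embedding(10, 3, sparse=True)
out = emb(torch.tensor([1, 2, 2]))
out.sum().backward()
print(emb.weight.grad.is_sparse)  # True: the gradient is a sparse COO tensor
```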
@ngimel yes I agree this follows embedding, was just checking if that is the current recommended way of doing things. Also, I looked at the error message you removed and I think you can just get rid of it. Note that there appears to be a similar problem with …; we can fix that in another commit, though.
Can we please call it …
```
  return grad.sparse_mask(at::SparseTensorRef(input));
}

Tensor gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
```
can you move this to a native function? (You can call it _gather_sparse_backward). The issue with the above is that because it's not a native function, our backend extensions can't override it.
test/test_autograd.py (Outdated)
```
if len(size_x) > 0:
    x = torch.ones(*size_x, requires_grad=True)
else:
    x = torch.randn((), requires_grad=True)
```
can you always pass size_x instead of having the if/else, i.e. `x = torch.randn(size_x, requires_grad=True)`?
test/test_autograd.py (Outdated)
```
    x = torch.randn((), requires_grad=True)
if len(size_ind) > 0 and len(size_x) > 0:
    ind = torch.randint(x.size(dim), size_ind)
elif len(size_ind) == 0:
```
can we replace the elif/else here with `ind = torch.zeros(size_ind, dtype=torch.int64)`?
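A quick, illustrative check (not part of the PR) that both suggested one-liners also cover the empty-size, i.e. scalar, case:
```
import torch

# Passing an empty tuple as the size yields a 0-d (scalar) tensor,
# so the if/else and elif/else branches in the test become unnecessary.
for size in [(), (2, 3)]:
    x = torch.randn(size, requires_grad=True)
    ind = torch.zeros(size, dtype=torch.int64)
    print(x.dim(), ind.dim())  # prints "0 0" for (), "2 2" for (2, 3)
```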
```
Tensor gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
  // special case scalar input and/or index
  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,1}, index.options()), grad, self.sizes());
```
this looks wrong because we in general allow scalar and 1-dimensional inputs to inter-operate. This is assuming that self.dim() == 0 implies grad.dim() == 0, but that's unfortunately not the case.
See this example:
```
>>> self = torch.randn((), requires_grad=True)
>>> torch.gather(self, 0, torch.zeros(2).long(), sparse=True).sum().backward()
```
That's a good catch, but what should the correct behavior be? The sparse gradient would be 1D (there's no way to make it scalar, because it has several elements). Is it ok to have a 1D sparse gradient for a scalar input? One way around would be to call .coalesce() and possibly reshape right away; that way the scalar gradient may be preserved. But then we are not calling .coalesce() for non-degenerate cases, and simply return sparse gradients with repeating indices.
sorry, why would the sparse gradient need to be 1-D? In the specific case, is it not correct to change
```
at::_sparse_coo_tensor_unsafe(at::empty({0,1}, index.options()), grad, self.sizes());
```
to
```
at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
```
Cool! Did not know it works like this, I thought indices.numel() and values.numel() should be the same.
I usually have to look up the invariants each time to double-check, here is where I look:
pytorch/aten/src/ATen/SparseTensorImpl.h, lines 11 to 15 (at c3a2337):
```
// INVARIANTS:
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// _indices.shape: dimensionality: 2,  shape: (sparse_dim, nnz)
// _values.shape:  dimensionality: 1 + dense_dim.  shape: (nnz, shape[sparse_dim:])
```
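A small Python sketch of these invariants (values chosen purely for illustration): indices has shape (sparse_dim, nnz), values has one entry per nnz, and repeated indices are summed on coalesce, which is why returning an uncoalesced gradient with repeating indices is still well defined:
```
import torch

# indices: (sparse_dim, nnz) = (2, 3); note the repeated entry (0, 1)
indices = torch.tensor([[0, 0, 1],
                        [1, 1, 0]])
# values: (nnz,) = (3,); indices.numel() (6) and values.numel() (3) need not match
values = torch.tensor([1., 2., 3.])
s = torch.sparse_coo_tensor(indices, values, size=(2, 2))
print(s.coalesce().to_dense())
# tensor([[0., 3.],
#         [3., 0.]])  -- the two entries at (0, 1) are summed
```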
Added previously failing case.
facebook-github-bot left a comment:
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot rebase this please
facebook-github-bot left a comment:
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@gchanan Anything I need to do on this PR? Is it failing internal tests?
@ngimel internal tests look good, let me try to merge it.
Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows the autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out the size.size() check in `SparseTensor.cpp` that also caused #17152; it does not seem to me that the check serves a useful purpose, but please correct me if I'm wrong and a better fix is required.
Motivating example:
For this commonly used label smoothing loss function
```
def label_smoothing_opt(x, target):
padding_idx = 0
smoothing = 0.1
logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
pad_mask = (target == padding_idx)
ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse = True).squeeze(1)
smooth_loss = logprobs.mean(dim=-1)
loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
loss.masked_fill_(pad_mask, 0)
return loss.sum()
```
backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients for 9K tokens x 30K vocab, which is a low-single-digit-percent end-to-end improvement, and also an improvement in peak memory required.
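For context, a rough sketch of how such a comparison could be reproduced (hypothetical; assumes a CUDA build with this PR's `sparse=` keyword available and the `label_smoothing_opt` function defined above):
```
import time
import torch

# Shapes from the description above: ~9K tokens x 30K vocab, measured on GPU.
x = torch.randn(9000, 30000, device='cuda', requires_grad=True)
target = torch.randint(1, 30000, (9000,), device='cuda')
loss = label_smoothing_opt(x, target)
torch.cuda.synchronize()
start = time.time()
loss.backward()
torch.cuda.synchronize()
print('backward: %.1f ms' % ((time.time() - start) * 1000))
```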
Shout-out to core devs: adding python-exposed functions with keyword arguments through native_functions.yaml is very easy now!
cc gchanan apaszke
Pull Request resolved: pytorch/pytorch#17182
Differential Revision: D14158431
Pulled By: gchanan
fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade