@ngimel
Collaborator

@ngimel ngimel commented Feb 15, 2019

This PR allows gather to optionally return sparse gradients, as requested in #16329. It also allows the autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out the size.size() check in SparseTensor.cpp that also caused #17152. That check does not seem to serve a useful purpose, but please correct me if I'm wrong and a better fix is required.
Motivating example:
For this commonly used label smoothing loss function:

import torch

def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    # Pick the log-probability of each target token; sparse=True requests a sparse gradient.
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse=True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    # Zero out the loss at padding positions.
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()

Backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients, for 9K tokens x 30K vocab. That is a single-digit percent end-to-end improvement, and it also reduces the peak memory required.
Shout-out to core devs: adding Python-exposed functions with keyword arguments through native_functions.yaml is very easy now!

cc @gchanan @apaszke

@facebook-github-bot facebook-github-bot added the oncall: jit label Feb 15, 2019
@ngimel ngimel force-pushed the sparse_gather branch 2 times, most recently from e861a36 to 31d33a8 Compare February 16, 2019 00:24
@soumith
Contributor

soumith commented Feb 19, 2019

CI is failing (maybe you need a rebase?)

Contributor

@gchanan gchanan left a comment

@soumith what's our current story with torch.sparse? If the input tensors can all be dense but we want the gradient to be sparse, we don't put it in torch.sparse?

Contributor

I'm not convinced getting rid of this is safe, but unfortunately the error message doesn't mention why it's not safe...

Let me dig into this a little bit more.

@gchanan
Contributor

gchanan commented Feb 19, 2019

The test failures I saw (test_alexnet) can be fixed by rebasing or merging in master.

@ngimel
Collaborator Author

ngimel commented Feb 19, 2019

> @soumith what's our current story with torch.sparse? If the input tensors can all be dense but we want the gradient to be sparse, we don't put it in torch.sparse?

This PR follows embedding: there, too, all the inputs are dense and gradient sparsity is controlled by a keyword argument, and embedding is not in torch.sparse.
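
For comparison, the embedding precedent can be sketched as follows. This is only a minimal illustration, assuming a recent PyTorch release in which torch.nn.functional.embedding accepts sparse=True; the tensor sizes and index values are made up.

```python
import torch
import torch.nn.functional as F

# Dense weight and dense integer indices: no input here is a sparse tensor.
weight = torch.randn(10, 3, requires_grad=True)
indices = torch.tensor([1, 2, 2])

# sparse=True only changes the layout of the gradient w.r.t. `weight`.
out = F.embedding(indices, weight, sparse=True)
out.sum().backward()

print(out.is_sparse)          # the forward result stays dense
print(weight.grad.is_sparse)  # the gradient is a sparse COO tensor
```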

@gchanan
Contributor

gchanan commented Feb 19, 2019

@ngimel yes, I agree this follows embedding; I was just checking whether that is the currently recommended way of doing things.

Also, I looked at the error message you removed and I think you can just get rid of it. Note that there appears to be a similar problem with to_sparse. I.e. even if you get rid of the error message above, the following still fails:

In [12]: torch.ones(()).to_sparse()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-12-8d76f7b6cb2b> in <module>()
----> 1 torch.ones(()).to_sparse()

RuntimeError: sparse_dim must be >0

We can fix that in another commit, though.

@apaszke
Contributor

apaszke commented Feb 19, 2019

Can we please call it sparse_grad? sparse=True looks as if it was to return a sparse tensor.
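
To illustrate the naming concern: the flag affects only the gradient, not the tensor that gather returns. A minimal sketch, assuming the keyword is spelled sparse_grad as suggested here (that is the spelling torch.gather uses in released PyTorch); the sizes are arbitrary.

```python
import torch

x = torch.randn(4, 5, requires_grad=True)
index = torch.randint(5, (4, 1))  # one column picked per row

out = torch.gather(x, 1, index, sparse_grad=True)
out.sum().backward()

print(out.is_sparse)     # the gathered values are an ordinary dense tensor
print(x.grad.is_sparse)  # only the gradient w.r.t. x is sparse
```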

return grad.sparse_mask(at::SparseTensorRef(input));
}

Tensor gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
Contributor

Can you move this to a native function? (You can call it _gather_sparse_backward.) The issue with the above is that, because it's not a native function, our backend extensions can't override it.

if len(size_x) > 0:
    x = torch.ones(*size_x, requires_grad=True)
else:
    x = torch.randn((), requires_grad=True)
Contributor

can you always pass size_x instead of having the if/else, i.e.
x = torch.randn(size_x, requires_grad=True)
?

x = torch.randn((), requires_grad=True)
if len(size_ind) > 0 and len(size_x) > 0:
    ind = torch.randint(x.size(dim), size_ind)
elif len(size_ind) == 0:
Contributor

can we replace the elif/else here with:
ind = torch.zeros(size_ind, dtype=torch.int64)
?
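
Taken together, the two review suggestions on this test reduce the setup to something like the sketch below. The helper name make_gather_inputs is hypothetical, and the code assumes (as the suggestions do) that torch.randn and torch.zeros accept an empty size tuple to produce 0-dim tensors.

```python
import torch

def make_gather_inputs(size_x, size_ind, dim):
    """Hypothetical helper: build an input and an index for torch.gather."""
    # randn takes the size tuple directly, so () needs no special case.
    x = torch.randn(size_x, requires_grad=True)
    if len(size_ind) > 0 and len(size_x) > 0:
        # Regular case: random valid positions along `dim`.
        ind = torch.randint(x.size(dim), size_ind)
    else:
        # Scalar input and/or scalar index: 0 is the only index that is always valid.
        ind = torch.zeros(size_ind, dtype=torch.int64)
    return x, ind

for size_x, size_ind in [((10, 10), (5, 5)), ((), (2,)), ((10,), ()), ((), ())]:
    x, ind = make_gather_inputs(size_x, size_ind, 0)
    out = torch.gather(x, 0, ind)
    print(tuple(x.shape), tuple(ind.shape), tuple(out.shape))
```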


Tensor gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
  // special case scalar input and/or index
  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,1}, index.options()), grad, self.sizes());
Contributor

This looks wrong because, in general, we allow scalar and 1-dimensional inputs to interoperate. The code assumes that self.dim() == 0 implies grad.dim() == 0, but that's unfortunately not the case.

See this example:

>>> self = torch.randn((), requires_grad=True)
>>> torch.gather(self, 0, torch.zeros(2).long(), sparse=True).sum().backward()
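
The shape mismatch can be seen without the sparse path at all. A minimal sketch using the default dense gradient, assuming gather accepts a 0-dim input with a 1-D index as in the example above (the variable names are illustrative):

```python
import torch

x = torch.randn((), requires_grad=True)     # 0-dim input
index = torch.zeros(2, dtype=torch.int64)   # 1-D index with two entries

out = torch.gather(x, 0, index)             # default dense gradient path
upstream = torch.ones_like(out)             # the `grad` that the backward function receives
out.backward(upstream)

print(x.dim(), upstream.dim())  # input is 0-dim, incoming gradient is 1-dim
print(x.grad)                   # both entries accumulate into the single scalar
```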

Collaborator Author

That's a good catch, but what should the correct behavior be? The sparse gradient would be 1D (there's no way to make it scalar, because it has several elements). Is it OK to have a 1D sparse gradient for a scalar input? One way around this would be to call .coalesce() and possibly reshape right away, so that a scalar gradient may be preserved. But we do not call .coalesce() for non-degenerate cases; there we simply return sparse gradients with repeating indices.
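
For readers unfamiliar with the trade-off mentioned here, a small sketch of what "repeating indices" and .coalesce() mean for a COO tensor; the values are illustrative only.

```python
import torch

# Index 0 appears twice, so this tensor is valid but uncoalesced.
indices = torch.tensor([[0, 0, 2]])
values = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(indices, values, (4,))
print(s.is_coalesced())

# coalesce() sorts the indices and sums the values of duplicates.
c = s.coalesce()
print(c.indices(), c.values())

# Both forms represent the same dense tensor.
print(torch.equal(s.to_dense(), c.to_dense()))
```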

Contributor

sorry, why would the sparse gradient need to be 1-D? In the specific case, is it not correct to change:

at::_sparse_coo_tensor_unsafe(at::empty({0,1}, index.options()), grad, self.sizes());

to:

at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
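
The same shape relationship can be reproduced from Python. This is a sketch that uses the public torch.sparse_coo_tensor constructor rather than the internal _sparse_coo_tensor_unsafe, under the assumption that the public constructor accepts an indices tensor with zero rows for a 0-dim result:

```python
import torch

grad = torch.ones(2)  # two gradient entries flowing back to a scalar input

# sparse_dim == 0, so indices has zero rows but still one column per value.
indices = torch.empty((0, grad.numel()), dtype=torch.int64)
sparse_grad = torch.sparse_coo_tensor(indices, grad, ())

print(sparse_grad.shape, sparse_grad.sparse_dim(), sparse_grad._nnz())
```

Note that indices.numel() is 0 while the values tensor has two elements: the number of nonzeros is the second dimension of indices, not its element count.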

Collaborator Author

Cool! I did not know it works like this; I thought indices.numel() and values.numel() should be the same.

Contributor

I usually have to look up the invariants each time to double-check; here is where I look:

// INVARIANTS:
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
// _values.shape: dimensionality: 1 + dense_dim. shape: (nnz, shape[sparse_dim:])
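
These invariants can be checked directly on a small hybrid tensor. A sketch with made-up sizes (one sparse dimension, one dense dimension):

```python
import torch

shape = (4, 3)
indices = torch.tensor([[0, 2]])   # (sparse_dim, nnz) = (1, 2)
values = torch.ones(2, 3)          # (nnz, *shape[sparse_dim:]) = (2, 3)
s = torch.sparse_coo_tensor(indices, values, shape)

sparse_dim, dense_dim, nnz = s.sparse_dim(), s.dense_dim(), s._nnz()
assert sparse_dim + dense_dim == len(shape)
assert tuple(s._indices().shape) == (sparse_dim, nnz)
assert tuple(s._values().shape) == (nnz, *shape[sparse_dim:])
```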

@ngimel
Collaborator Author

ngimel commented Feb 20, 2019

Added previously failing case
torch.gather(self, 0, torch.zeros(2).long(), sparse=True).sum().backward()
to tests.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor

ezyang commented Feb 21, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator Author

ngimel commented Feb 27, 2019

@gchanan Anything I need to do on this PR? Is it failing internal tests?

@gchanan
Contributor

gchanan commented Feb 27, 2019

@ngimel internal tests look good, let me try to merge it.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Feb 27, 2019
Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows to autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out size.size() check in `SparseTensor.cpp` that also caused #17152, it does not seem to me that check serves a useful purpose, but please correct me if I'm wrong and a better fix is required.
Motivating example:
For this commonly used label smoothing loss function
```
def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse = True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss =  (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```
backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients, for 9K tokens x 30K vocab, which is some single percent end-to-end improvement, and also improvement in peak memory required.
Shout-out to core devs: adding python-exposed functions with keyword arguments through native_functions.yaml is very easy now!

cc gchanan apaszke
Pull Request resolved: pytorch/pytorch#17182

Differential Revision: D14158431

Pulled By: gchanan

fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade
@ngimel ngimel deleted the sparse_gather branch April 4, 2019 00:56
facebook-github-bot pushed a commit that referenced this pull request Jun 11, 2021
Summary:
Fixes an issue introduced in  #17182

Pull Request resolved: #59817

Reviewed By: bdhirsh

Differential Revision: D29040738

Pulled By: albanD

fbshipit-source-id: 67fd4e9fa0dadf507ddd954d20e119d8781c4de0

Labels

oncall: jit, open source

7 participants