Conversation

pietern (Contributor) commented Apr 11, 2019

Stack:
    #19443 Support sparse gradients in DistributedDataParallel
    #19146 Add sparse tensor allreduce (this PR)

Implemented only on ProcessGroupGloo, as an allgather of metadata
(sparse_dim, dense_dim, and nnz), followed by an allgather of indices,
followed by an allgather of values. Once these operations have
finished, all ranks locally compute a reduction over these sparse
tensors. Works for both CPU and CUDA tensors.
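Roughly, the local reduction amounts to something like this (a Python sketch of the idea, not the actual C++ implementation; all_indices/all_values stand in for the gathered pieces):

```python
import torch

# Hypothetical gathered pieces: one (indices, values) pair per rank.
all_indices = [torch.tensor([[0, 2]]), torch.tensor([[2, 3]])]
all_values = [torch.tensor([1.0, 1.0]), torch.tensor([2.0, 3.0])]
size = (4,)

# Rebuild one sparse COO tensor per rank and sum them locally.
pieces = [
    torch.sparse_coo_tensor(idx, val, size)
    for idx, val in zip(all_indices, all_values)
]
reduced = pieces[0]
for p in pieces[1:]:
    reduced = reduced + p
reduced = reduced.coalesce()  # merge duplicate indices into a canonical result
```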

This surfaced a problem with the existing assumption of only modifying
tensors that are passed at the call site, because for sparse tensors
we don't know the dimensions of the output tensors before we run the
collective. To deal with this unknown, this commit adds a result
function to the c10d::ProcessGroup::Work class that returns a vector
of tensors.

It's a bit odd to have to retrieve the result through this function
only for operations on sparse tensors. To make this work irrespective
of tensor layout, a follow-up commit can make all in-place operations
expose their results through this function as well. This doesn't break
any existing contracts, but it does have the potential to add interface
ambiguity.
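For illustration, a caller-side sketch of retrieving the result this way, assuming the Python bindings expose the new Work.result() (the exact surface may be version-dependent):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("gloo", ...) has already run on each process.
t = torch.sparse_coo_tensor([[0, 3]], [1.0, 2.0], (8,))

work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
work.wait()

# For sparse inputs the output shape isn't known up front, so the reduced
# tensor is fetched through the work handle instead of being read back in place.
reduced = work.result()[0]  # assumption: Work.result() is exposed in Python
```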

Differential Revision: D14889547

Differential Revision: D14889547
Differential Version: 78974033
pietern added the oncall: distributed and triaged labels on Apr 11, 2019
Differential Revision: D14889547
Differential Version: 79084183
} else {
// We will need to coalesce first, which means new tensors will
// be allocated on the streams we just allocated, and there
// is no need to record them separately.
Contributor:

If the op is not synchronized, tensors can be freed even before coalesce happens. Isn't that still a problem?

Contributor Author (pietern):

They won't be freed, because the op holds on to the input tensors. The coalesce on this new stream happens after any prior operations against the inputs, thanks to the events[i].block() call that runs before this block.
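A Python analogue of that ordering, for illustration only (the stream/event names here are made up; the real logic lives in ProcessGroupGloo's C++ code):

```python
import torch

main_stream = torch.cuda.current_stream()
side_stream = torch.cuda.Stream()

sparse = torch.sparse_coo_tensor([[0, 1]], [1.0, 2.0], (4,), device="cuda")

# Record an event after all work queued against the input on the main stream,
# then make the side stream wait on it (the analogue of events[i].block()).
ready = torch.cuda.Event()
ready.record(main_stream)
side_stream.wait_event(ready)

# Coalesce on the side stream; tensors allocated here are born on that stream,
# so they don't need a separate record_stream() call.
with torch.cuda.stream(side_stream):
    coalesced = sparse.coalesce()

# Keep downstream work ordered after the side stream.
main_stream.wait_stream(side_stream)
```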

//
// - [0:4]: sparse dims
// - [4:8]: dense dims
// - [8]: nnz
Contributor:

what is nnz? number of non-zero elements?

Contributor Author (pietern):

Yes.
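Concretely, for a COO tensor:

```python
import torch

# 2 sparse dims, 1 dense dim (each stored entry is a 3-element vector), nnz = 3.
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])  # shape: (sparse_dim, nnz)
values = torch.randn(3, 3)                      # shape: (nnz, *dense_sizes)
t = torch.sparse_coo_tensor(indices, values, (4, 4, 3)).coalesce()

print(t.sparse_dim())  # 2
print(t.dense_dim())   # 1
print(t._nnz())        # 3 -- number of stored (non-zero) entries
```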

//
// The layout of this memory is as follows:
//
// - [0:4]: sparse dims
Contributor:

When merging this PR, let's create an issue to generalize it.

Contributor Author (pietern):

Sure. The reason I picked 4 here is that I never saw anything bigger, and we need a constant maximum that is shared between processes, or else we'd need another collective before we can exchange these details.
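A sketch of that fixed-size encoding, matching the [0:4]/[4:8]/[8] layout quoted above (the helper and buffer names are hypothetical, not the ProcessGroupGloo code):

```python
import torch

MAX_SPARSE_DIMS = 4
MAX_DENSE_DIMS = 4

def pack_metadata(t: torch.Tensor) -> torch.Tensor:
    """Pack sparse dims, dense dims, and nnz into a fixed 9-element long buffer."""
    buf = torch.zeros(MAX_SPARSE_DIMS + MAX_DENSE_DIMS + 1, dtype=torch.long)
    sizes, sd, dd = t.size(), t.sparse_dim(), t.dense_dim()
    buf[0:sd] = torch.tensor(sizes[:sd])                                  # [0:4]: sparse dims
    buf[MAX_SPARSE_DIMS:MAX_SPARSE_DIMS + dd] = torch.tensor(sizes[sd:])  # [4:8]: dense dims
    buf[-1] = t._nnz()                                                    # [8]: nnz
    return buf

t = torch.sparse_coo_tensor([[0, 1], [1, 2]], torch.randn(2, 3), (4, 4, 3)).coalesce()
print(pack_metadata(t))  # tensor([4, 4, 0, 0, 3, 0, 0, 0, 2])
# Every rank would allgather this fixed-size buffer before exchanging indices/values.
```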


// Gather all indices and all values.
auto indices = allgather_indices(input, metadata);
auto values = allgather_values(input, metadata);
Contributor:

should these two be done in one allgather?

Contributor Author (pietern):

This should be possible in theory, since we have all the data. But the data types may be different: the indices are always long, while the values may be any dtype. To combine them in a single call we'd need to cast both to char buffers, since the combined allgather would operate at the byte level.
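To make the two separate gathers concrete, here is a sketch that pads indices and values to the largest nnz so fixed-size all_gather calls can be used; the padding strategy and function name are assumptions for illustration, not a description of the Gloo implementation:

```python
import torch
import torch.distributed as dist

def allgather_sparse_pieces(t, all_nnz, world_size):
    """Gather every rank's indices (always int64) and values (any dtype) separately."""
    t = t.coalesce()
    max_nnz = max(all_nnz)  # per-rank nnz comes from the earlier metadata allgather

    # Pad the local pieces out to max_nnz so sizes match across ranks.
    idx = torch.zeros(t.sparse_dim(), max_nnz, dtype=torch.long)
    idx[:, :t._nnz()] = t._indices()
    val = torch.zeros(max_nnz, *t._values().shape[1:], dtype=t._values().dtype)
    val[:t._nnz()] = t._values()

    idx_out = [torch.empty_like(idx) for _ in range(world_size)]
    val_out = [torch.empty_like(val) for _ in range(world_size)]
    dist.all_gather(idx_out, idx)  # one gather for long indices
    dist.all_gather(val_out, val)  # one gather for values in their own dtype

    # Trim the padding back off using each rank's true nnz.
    return ([i[:, :n] for i, n in zip(idx_out, all_nnz)],
            [v[:n] for v, n in zip(val_out, all_nnz)])
```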

// Copy back to input tensors.
outputs.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
outputs.push_back(output.clone());
Contributor:

Is this only for the all_reduce_multigpu API? If so, should we throw an error instead if inputs.size() > 1? Because the docs say:

The function operates in-place and requires that each tensor to be a GPU tensor on different GPUs.

Contributor Author (pietern):

We should update and deprecate all_reduce_multigpu instead, IMO. We already have multi-input collectives (for some operations at least), so I added this one here for parity with those. If we end up with a replicated DDP module that contains an embedding bag, the reducer will end up calling this with multiple sparse inputs. This functionality needs to live somewhere, and I'd rather have it here than move it back into the reducer, where we'd compute a sum across replicas before passing it to the process group. Doing so would introduce an asymmetry with ProcessGroupNCCL, where multi-input allreduce is valid (and will defer summing to NCCL).

/*non_blocking=*/true,
/*copy=*/true);

outputs.push_back(
Contributor:

If we are not writing back to the input tensors for the sparse use case, we should edit the Python docs accordingly. BTW, can we still write to the input tensors using tensor.set_()?

pietern added 2 commits April 15, 2019 11:38
Differential Revision: D14889547
Differential Version: 79475049
Differential Revision: D14889547
Differential Version: 80027945
pietern added 2 commits June 19, 2019 02:17
Differential Revision: D14889547
Differential Version: 85199091
Differential Revision: D14889547
Differential Version: 85201710
pietern (Contributor Author) commented Jun 19, 2019

@mrshenli The variable-to-tensor unboxing is no longer needed now that there has been some progress on the tensor/variable merge. The update also removes some verbose copying code (separate copy calls for the indices and values) and replaces it with a straightforward tensor.to().
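For context, this is the kind of call that replaces the hand-rolled copies (a generic Tensor.to illustration, not the exact diff):

```python
import torch

src = torch.randn(8, device="cuda" if torch.cuda.is_available() else "cpu")

# One call covers device movement and forces a copy; non_blocking only has an
# effect for pinned-memory host/device transfers.
dst = src.to("cpu", non_blocking=True, copy=True)
```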

facebook-github-bot (Contributor): This pull request has been merged in aee6a41.

facebook-github-bot pushed a commit that referenced this pull request Jun 24, 2019
Summary:
Pull Request resolved: #22036

Implemented only on ProcessGroupGloo, as an allgather of metadata
(sparse_dim, dense_dim, and nnz), followed by an allgather of indices,
followed by an allgather of values. Once these operations have
finished, all ranks locally compute a reduction over these sparse
tensors. Works for both CPU and CUDA tensors.

This surfaced a problem with the existing assumption of only modifying
tensors that are passed at the call site, because for sparse tensors
we don't know the dimensions of the output tensors before we run the
collective. To deal with this unknown, this commit adds a `result`
function to the `c10d::ProcessGroup::Work` class that returns a vector
of tensors.

It's a bit odd to have to retrieve the result through this function
only for operations on sparse tensors. To make this work irrespective
of tensor layout, a follow-up commit can make all in-place operations
expose their results through this function as well. This doesn't break
any existing contracts, but it does have the potential to add interface
ambiguity.

This is a resubmission of #19146.

Reviewed By: mrshenli

Differential Revision: D15926384

fbshipit-source-id: b6ee5d81606bfa8ed63c3d63a9e307613491e0ae
lironmo commented Jun 26, 2019

@pietern When is this expected to be merged, and what will be the RC version?

pietern (Contributor Author) commented Jun 26, 2019

@lironmo This was merged in #22036 and you can use it today with the nightly builds.

lironmo commented Jul 1, 2019

@pietern Thanks, I tried it on the Gloo backend and it works.

I have a few questions:
1. Can I use it with the NCCL backend?
2. It seems that support for sparse tensors is limited (mainly missing serialization; I converted to a COO matrix and used pickle). Are you going to support torch.save?
3. Probably not related to this PR, but it seems there is a major change in the LR scheduler (related to the base LR). Up until this version, calling lr_scheduler.step(epoch_metrics) before optimizer.step() worked, but now it breaks the multi-step scheduler...

pietern (Contributor Author) commented Jul 1, 2019

Hi @lironmo,

  1. No, this is implemented only for the Gloo backend at this point in time. I added #22400 (Sparse allreduce for ProcessGroupNCCL) to track it.
  2. Can you create an issue for the missing functionality? Then we'll be able to track and prioritize it.
  3. This doesn't sound related to sparse allreduce. You still need an optimizer that works with sparse gradients, of course (see the sketch below).
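A minimal sketch of point 3, assuming a Gloo process group and an embedding layer that emits sparse gradients (the sizes and names are illustrative; SparseAdam is one optimizer that accepts sparse gradients):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group("gloo", ...) has already run on each process.
model = nn.EmbeddingBag(10000, 64, sparse=True)  # produces sparse gradients
ddp = DDP(model)                                 # gradient allreduce takes the sparse path
optimizer = torch.optim.SparseAdam(ddp.parameters())

inputs = torch.randint(0, 10000, (32, 8))
loss = ddp(inputs).sum()
loss.backward()   # sparse grads are allreduced via the Gloo backend
optimizer.step()
```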

lironmo commented Jul 1, 2019

Thanks.

ezyang deleted the export-D14889547 branch July 19, 2019 15:48