Skip to content

Conversation

@pietern
Copy link
Contributor

@pietern pietern commented Jun 20, 2019

Stack:
    :white_circle:  #22037 Support sparse gradients in DistributedDataParallel  💛
    :black_circle:  #22036 Add sparse tensor allreduce  💛

Implemented only on ProcessGroupGloo, as an allgather of metadata
(sparse_dim, dense_dim, and nnz), followed by an allgather of indices,
followed by an allgather of values. Once these operations have
finished, all ranks locally compute a reduction over these sparse
tensors. Works for both CPU and CUDA tensors.

This surfaced a problem with the existing assumption of only modifying
tensors that are passed at the call site, because for sparse tensors
we don't know the dimensions of the output tensors before we run the
collective. To deal with this unknown, this commit adds a result
function to the c10d::ProcessGroup::Work class that returns a vector
of tensors.

It's a bit odd to have to retrieve the result through this function
only for operations on sparse tensors. To make this work irrespective
of tensor layout, we can create a follow-up commit to make all in
place operations make their results accessible through this function
as well. This doesn't break any existing contracts but does have the
potential to add interface ambiguity.

This is a resubmission of #19146.

Differential Revision: D15926384

Differential Revision: D15926384
Differential Version: 85311082
@pietern pietern requested review from apaszke and mrshenli as code owners June 20, 2019 19:37
@pytorchbot pytorchbot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 20, 2019
@pietern
Copy link
Contributor Author

pietern commented Jun 21, 2019

@pytorchbot retest this please

The Windows failures are likely unrelated and to be fixed by #22029.

@pietern pietern added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 21, 2019
@pietern
Copy link
Contributor Author

pietern commented Jun 24, 2019

@pytorchbot retest this please

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in a7ec889.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants