[RFC][distributed] RFC: c10d ProcessGroup extension and C++ API change #39662
Description
Proposal: c10d ProcessGroup extension and C++ API change
Purpose
Allow clients to provide custom implementations of the ProcessGroup API and use them as a CCL backend without needing to modify the PyTorch codebase. The registration mechanism is implemented in #28068.
Usage example
Implement the single-rank constructor of your custom ProcessGroup, called `custom_backend`, in the C++ module `custom_backend_cpp_module` (c++):
Register your backend (c++, pybind):
```cpp
dist.Backend.register_backend("custom_backend", py::cpp_function(GetNewRank));
```
Load your module (py):
```python
torch.utils.cpp_extension.load(name="custom_backend_cpp_module", sources=[...])
```
Initialize process group (py):
```python
dist.init_process_group(backend="custom_backend", init_method=...)
```
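At its core, the registration mechanism is a name-to-constructor map that `init_process_group` consults. A minimal stand-alone sketch in plain Python (hypothetical names throughout, not the actual torch.distributed internals):

```python
# Toy registry illustrating the register_backend pattern.
# All names here are hypothetical, not torch.distributed internals.
_backend_registry = {}

def register_backend(name, constructor):
    """Map a backend name to a ProcessGroup constructor."""
    if name in _backend_registry:
        raise ValueError(f"backend {name!r} already registered")
    _backend_registry[name] = constructor

def init_process_group(backend, rank, world_size):
    """Look up the registered constructor and build the group."""
    try:
        ctor = _backend_registry[backend]
    except KeyError:
        raise ValueError(f"unknown backend {backend!r}") from None
    return ctor(rank, world_size)

class CustomProcessGroup:
    """Stand-in for a 3rd-party ProcessGroup implementation."""
    def __init__(self, rank, world_size):
        self.rank, self.world_size = rank, world_size

register_backend("custom_backend", CustomProcessGroup)
pg = init_process_group("custom_backend", rank=0, world_size=2)
```

The key property is that the custom backend is wired in purely through registration, with no changes to the dispatching code.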
Proposed changes to the current API
ProcessGroup API has some inconsistencies, so we’d like to make it more uniform across different functionalities.
For some operations, such as allgather and reducescatter, we allow an in-place mode: the user specifies in_place = true in the options of the corresponding operation and thereby guarantees that the inputs remain unchanged until the operation completes. Implementors of ProcessGroup can (and are encouraged to) take advantage of this mode to speed up their implementations, e.g. see #33924.
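As a sketch of what the in-place contract enables (hypothetical option and function names, plain Python lists standing in for tensors):

```python
from dataclasses import dataclass

@dataclass
class AllgatherOptions:
    in_place: bool = False  # hypothetical flag illustrating the proposal

def local_allgather_copy(output_chunks, input_chunk, rank, opts):
    # With in_place=True the caller guarantees that output_chunks[rank]
    # already aliases its input, so the implementation may skip the
    # local copy and only perform the inter-rank exchange.
    if not (opts.in_place and output_chunks[rank] is input_chunk):
        output_chunks[rank] = list(input_chunk)
    return output_chunks
```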
Some operations (reduce, allreduce, gather) have multi-GPU versions, where a single process group controls multiple GPUs. These are mostly used by data parallel. [Should we mention: for newer applications, we suggest having one process-group rank per GPU?]
What we are changing in the new API to make it more straightforward:
- Batched operations (whose main purpose is efficiency from combining tensors), i.e. allreduce_coalesced and allgather_coalesced, are going to live outside the ProcessGroup API, so 3rd parties do not need to implement them.
- Adding single-GPU versions of allreduce, gather, scatter.
- Removing the multi-GPU versions of scatter and gather, which aren't supported by any of the existing backends.
All to all
Current API (keeping):
```cpp
virtual std::shared_ptr<ProcessGroup::Work> alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts = AllToAllOptions()) = 0;
```
Current support: MPI only. (Clients are encouraged to optimize based on tensor layout, i.e. avoid extra tensor copies when tensors share storage and are properly aligned.)
Current API (removing):
```cpp
virtual std::shared_ptr<ProcessGroup::Work> alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts = AllToAllOptions()) = 0;
```
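For reference, the semantics of the base variant can be sketched in plain Python: the split sizes partition a single flat buffer, with each rank's chunk starting at the prefix sum of the preceding sizes, and each rank's i-th chunk landing in rank i's output (illustrative code, not the actual implementation):

```python
from itertools import accumulate

def split_offsets(split_sizes):
    """Start offset of each rank's chunk: prefix sums of the sizes."""
    return [0] + list(accumulate(split_sizes))[:-1]

def split_buffer(flat, split_sizes):
    """Cut a flat input buffer into per-rank chunks."""
    offs = split_offsets(split_sizes)
    return [flat[o:o + n] for o, n in zip(offs, split_sizes)]

def alltoall(chunks_per_rank):
    """Rank src's dst-th chunk lands in rank dst's src-th output slot."""
    world = len(chunks_per_rank)
    return [[chunks_per_rank[src][dst] for src in range(world)]
            for dst in range(world)]
```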
All reduce
Current API (keeping until the status of multi-GPU PG is clear):
```cpp
allreduce(std::vector<at::Tensor>& data, const AllreduceOptions& opts) = 0;
```
Current support: coalesced version is supported in GLOO only.
Adding single-tensor API:
```cpp
// Not included in the API at the moment: allreduce a single tensor.
allreduce(at::Tensor& tensor, const AllreduceOptions& opts) = 0;
```
No longer part of the API: allreduce_coalesced will be moved out of ProcessGroup (no need for 3rd parties to implement it) and included in comms instead.
```cpp
// Being deprecated and moved to comms, can be done via allreduce_base.
allreduce_coalesced(std::vector<at::Tensor>& tensors,
                    const AllreduceCoalescedOptions& opts);
```
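The "done via allreduce_base" path amounts to flatten, a single allreduce, then unflatten. A toy sketch with plain Python lists standing in for 1-D tensors (illustrative only):

```python
def flatten(tensors):
    """Pack a list of 1-D tensors into one flat buffer, remembering sizes."""
    sizes = [len(t) for t in tensors]
    flat = [x for t in tensors for x in t]
    return flat, sizes

def unflatten(flat, sizes):
    """Inverse of flatten: cut the buffer back into the original shapes."""
    out, i = [], 0
    for n in sizes:
        out.append(flat[i:i + n])
        i += n
    return out

def allreduce_coalesced(tensor_lists_per_rank):
    """Emulate the coalesced op: one elementwise-sum allreduce over flat buffers."""
    flats = [flatten(ts)[0] for ts in tensor_lists_per_rank]
    sizes = [len(t) for t in tensor_lists_per_rank[0]]
    reduced = [sum(vals) for vals in zip(*flats)]  # the single allreduce
    return unflatten(reduced, sizes)
```

Doing one reduction over the flat buffer instead of one per tensor is exactly the efficiency the coalesced variant exists for.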
Gather
Current API (not keeping):
```cpp
gather(std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
       const GatherOptions& opts) = 0;
```
Current support: MPI; GLOO supports single-element gather only; NCCL does not support gather.
Desirable API:
```cpp
gather(std::vector<at::Tensor>& outputTensors,
       at::Tensor& inputTensor,
       const GatherOptions& opts) = 0;
```
All gather
Current API:
```cpp
virtual std::shared_ptr<ProcessGroup::Work> allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts = AllgatherOptions()) = 0;
```
Current support: coalesced version is supported in GLOO only.
Desirable API (adding):
```cpp
virtual std::shared_ptr<ProcessGroup::Work> allgather(
    std::vector<at::Tensor>& outputTensors,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts = AllgatherOptions()) = 0;
```
(Clients are encouraged to optimize based on tensor layout, i.e. avoid extra tensor copies when tensors share storage and are properly aligned.)
Not part of the API (same as with allreduce_coalesced):
```cpp
// Being deprecated and moved to comms, can be done via allgather_base.
virtual std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts = AllgatherOptions());
```
Scatter
Current API (not keeping):
```cpp
virtual std::shared_ptr<ProcessGroup::Work> scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts = ScatterOptions()) = 0;
```
Current support: GLOO; MPI supports the single-tensor version only; NCCL does not support scatter.
Desired API (adding):
```cpp
virtual std::shared_ptr<ProcessGroup::Work> scatter(
    at::Tensor& outputTensor,
    std::vector<at::Tensor>& inputTensors,
    const ScatterOptions& opts = ScatterOptions()) = 0;
```
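In the single-tensor form the two ops are duals; a toy sketch of the data movement (plain Python lists standing in for tensors, illustrative only):

```python
def gather(input_per_rank):
    """Root's outputTensors: one slot per rank, filled with that rank's input."""
    return list(input_per_rank)

def scatter(input_tensors, rank):
    """Each rank's outputTensor: the root's rank-th input tensor."""
    return input_tensors[rank]
```

Note that scattering the result of a gather returns each rank its original input.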
No changes:
Reduce
```cpp
reduce(std::vector<at::Tensor>& tensors, const ReduceOptions& opts) = 0;
```
Current support: the multi-tensor form exists for the use case where the tensors live on different GPUs. GLOO and MPI support single-tensor input only; NCCL supports multiple tensors.
ReduceScatter
Current API:
```cpp
virtual std::shared_ptr<ProcessGroup::Work> reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0;
```
Current support: NCCL only.
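reduce_scatter composes the two operations: each rank contributes one tensor per destination, and rank j receives the elementwise reduction of every rank's j-th contribution. A toy Python sketch with sum as the reduction (illustrative only):

```python
def reduce_scatter(inputs_per_rank):
    """inputs_per_rank[i][j]: the tensor rank i contributes toward rank j's output."""
    world = len(inputs_per_rank)
    return [
        # rank j's output: elementwise sum of every rank's j-th contribution
        [sum(vals) for vals in zip(*(inputs_per_rank[i][j] for i in range(world)))]
        for j in range(world)
    ]
```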
Broadcast
```cpp
virtual std::shared_ptr<ProcessGroup::Work> broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts = BroadcastOptions()) = 0;
```
Barrier
```cpp
barrier(const BarrierOptions& opts) = 0;
```
Send/receive
```cpp
virtual std::shared_ptr<ProcessGroup::Work> send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) = 0;

virtual std::shared_ptr<ProcessGroup::Work> recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) = 0;

virtual std::shared_ptr<ProcessGroup::Work> recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) = 0;
```
Backend/feature support table
| Backend | gloo | | mpi | | nccl | |
|---|---|---|---|---|---|---|
| Device | CPU | GPU | CPU | GPU | CPU | GPU |
| send | ✓ | ✘ | ✓ | ? | ✘ | ✘ |
| recv | ✓ | ✘ | ✓ | ? | ✘ | ✘ |
| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ |
| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ |
| reduce | ✓ | ✘ | ✓ | ? | ✘ | ✓ |
| all_gather | ✓ | ✘ | ✓ | ? | ✘ | ✓ |
| gather | ✓ | ✘ | ✓ | ? | ✘ | ✘ |
| scatter | ✓ | ✘ | ✓ | ? | ✘ | ✘ |
| reduce_scatter | ✘ | ✘ | ✘ | ✘ | ✘ | ✓ |
| all_to_all | ✘ | ✘ | ✓ | ? | ✘ | ✘ |
| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ |
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar