
[c10d] Make alltoall as a custom op #79691

Closed
alanwaketan wants to merge 4 commits into gh/alanwaketan/40/base from gh/alanwaketan/40/head

Conversation

@alanwaketan
Collaborator

@alanwaketan alanwaketan commented Jun 16, 2022

Stack from ghstack (oldest at bottom):

Summary:
This patch makes alltoall a custom op so that it is dispatcher passable. It is one part of the effort to route comm ops to the dispatcher so that tracing mechanisms that rely on the dispatcher can trace them, e.g., LazyTensor and AOTAutograd.

Test Plan:
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_cuda
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_cuda_complex
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_full_group_cuda
and other existing distributed tests.


[ghstack-poisoned]
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 16, 2022

✅ No Failures (0 Pending)

As of commit 1d79a54 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 16, 2022
alanwaketan pushed a commit that referenced this pull request Jun 16, 2022
ghstack-source-id: 53b962b
Pull Request resolved: #79691
[ghstack-poisoned]
alanwaketan pushed a commit that referenced this pull request Jun 16, 2022

ghstack-source-id: 4cb4d44
Pull Request resolved: #79691
[ghstack-poisoned]
@alanwaketan alanwaketan requested a review from wanchaol June 22, 2022 23:46
@mrshenli mrshenli added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 24, 2022
}

c10::intrusive_ptr<ProcessGroup::Work> alltoall_(
at::TensorList output_tensors,
Contributor
why the type here is different from other comm ops?

Contributor
not stamping yet due to this comment

Collaborator Author
I think it's a convention to use TensorList to represent const std::vector<at::Tensor>&.

I only use const std::vector<at::Tensor>& if there is a const std::vector<std::vector<at::Tensor>>& in the signature, to keep that function's signature consistent. Let me know which way you prefer.

Contributor
I see. Can we add some docs to explain that in the code? Thanks!

const std::vector<std::vector<at::Tensor>>& input_tensors,
const ScatterOptions& opts ={});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work> alltoall(const c10::intrusive_ptr<ProcessGroup>& process_group,
at::TensorList output_tensors,
Contributor
Shall we run clang-format on this file? It might result in a different format.

Collaborator Author
The pull request passes the linter. And the format seems consistent with other function signatures in the same file. What's your concern here?

Contributor
We used to run clang-format on all distributed cpp files. Not sure what's today's convention. Below is what clang-format gives me:

namespace c10d {
namespace ops {

// Below are essentially ProcessGroup's corresponding ops but routed to the
// dispatcher.
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
broadcast(const c10::intrusive_ptr<ProcessGroup> &process_group,
          at::TensorList tensors, const BroadcastOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
allreduce(const c10::intrusive_ptr<ProcessGroup> &process_group,
          at::TensorList tensors, const AllreduceOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
allgather(const c10::intrusive_ptr<ProcessGroup> &process_group,
          const std::vector<std::vector<at::Tensor>> &output_tensors,
          const std::vector<at::Tensor> &input_tensors,
          const AllgatherOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
reduce_scatter(const c10::intrusive_ptr<ProcessGroup> &process_group,
               const std::vector<at::Tensor> &output_tensors,
               const std::vector<std::vector<at::Tensor>> &input_tensors,
               const ReduceScatterOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
reduce(const c10::intrusive_ptr<ProcessGroup> &process_group,
       at::TensorList tensors, const ReduceOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
gather(const c10::intrusive_ptr<ProcessGroup> &process_group,
       const std::vector<std::vector<at::Tensor>> &output_tensors,
       const std::vector<at::Tensor> &input_tensors,
       const GatherOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
scatter(const c10::intrusive_ptr<ProcessGroup> &process_group,
        const std::vector<at::Tensor> &output_tensors,
        const std::vector<std::vector<at::Tensor>> &input_tensors,
        const ScatterOptions &opts = {});
TORCH_API c10::intrusive_ptr<ProcessGroup::Work>
alltoall(const c10::intrusive_ptr<ProcessGroup> &process_group,
         at::TensorList output_tensors, at::TensorList input_tensors,
         const AllToAllOptions &opts = {});

} // namespace ops
} // namespace c10d

Collaborator Author
Somehow, my local linter doesn't produce the same output as yours. Let me just copy yours.

.def(
"alltoall",
[](::c10d::ProcessGroup& pg,
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
Contributor
ditto on naming

Collaborator Author
Let me follow up on #80246.

Contributor

@mrshenli mrshenli left a comment
LGTM. Left a minor comment

@alanwaketan
Collaborator Author

@pytorchbot merge --green

@alanwaketan
Collaborator Author

Thanks for approving this pull request, Shen and Wanchao.

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Jun 29, 2022
Summary:
This patch makes alltoall a custom op so that it is dispatcher passable. It is one part of the effort to route comm ops to the dispatcher so that tracing mechanisms that rely on the dispatcher can trace them, e.g., LazyTensor and AOTAutograd.

Pull Request resolved: #79691
Approved by: https://github.com/mrshenli, https://github.com/wanchaol

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/dcd17357a4aa7ed175dc6cfd46f502499dd4bdb0

Test plan from GitHub:
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_cuda
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_cuda_complex
BACKEND=nccl WORLD_SIZE=2 python test/distributed/test_distributed_spawn.py -v TestDistBackendWithSpawn.test_all_to_all_full_group_cuda
and other existing distributed tests.

Reviewed By: atalman

Differential Revision: D37455827

Pulled By: alanwaketan

fbshipit-source-id: 6745fd7d81a89b47e291da786e4511a6ce76be12
@facebook-github-bot facebook-github-bot deleted the gh/alanwaketan/40/head branch June 30, 2022 14:17

Labels

ciflow/trunk, cla signed, Merged, oncall: distributed, release notes: distributed (c10d), topic: new features


5 participants