[c10d] Make reduce_scatter as a custom op #79683

Closed
alanwaketan wants to merge 3 commits into gh/alanwaketan/36/base from gh/alanwaketan/36/head

alanwaketan (Collaborator) commented Jun 16, 2022

Stack from ghstack (oldest at bottom):

Summary:
This patch makes reduce_scatter a custom op so that it can pass through the
dispatcher. It is one part of the effort to route comm ops through the
dispatcher so that tracing mechanisms that rely on the dispatcher, e.g.,
LazyTensor and AOTAutograd, can trace them.

Test Plan:
python test/distributed/test_c10d_nccl.py -k test_reduce_scatter_ops
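
To make the motivation in the summary concrete, here is a minimal single-process sketch of why sending an op through one central dispatch point lets a tracer see it. Everything in it is hypothetical and for illustration only: `ToyDispatcher`, `toy_reduce_scatter_sum`, and the op name `toy::reduce_scatter_` are not PyTorch APIs, and the reduction is simulated over plain vectors rather than real process groups.

```cpp
// Toy model only: none of these names are PyTorch or c10d APIs.
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// inputs[r][k] is the chunk that simulated rank r contributes toward rank k.
using Chunk = std::vector<double>;
using RankInputs = std::vector<std::vector<Chunk>>;
using Kernel = std::function<std::vector<Chunk>(const RankInputs&)>;

// Simulated reduce_scatter with a sum reduction: output[k] is the
// element-wise sum over all ranks r of inputs[r][k].
// Assumes every rank supplies one chunk per rank, all of equal length.
std::vector<Chunk> toy_reduce_scatter_sum(const RankInputs& inputs) {
  const std::size_t world = inputs.size();
  std::vector<Chunk> outputs(world);
  for (std::size_t k = 0; k < world; ++k) {
    outputs[k].assign(inputs[0][k].size(), 0.0);
    for (std::size_t r = 0; r < world; ++r) {
      for (std::size_t i = 0; i < outputs[k].size(); ++i) {
        outputs[k][i] += inputs[r][k][i];
      }
    }
  }
  return outputs;
}

// Hypothetical dispatcher: a name-to-kernel table with one call path.
class ToyDispatcher {
 public:
  void register_op(const std::string& name, Kernel kernel) {
    kernels_[name] = std::move(kernel);
  }

  // Every call goes through here, so a tracer can record each op by name.
  std::vector<Chunk> call(const std::string& name, const RankInputs& inputs) {
    auto it = kernels_.find(name);
    if (it == kernels_.end()) {
      throw std::runtime_error("unknown op: " + name);
    }
    trace_.push_back(name);
    return it->second(inputs);
  }

  const std::vector<std::string>& trace() const { return trace_; }

 private:
  std::map<std::string, Kernel> kernels_;
  std::vector<std::string> trace_;
};
```

With two simulated ranks, registering the kernel under `toy::reduce_scatter_` and calling it through `call` yields the per-rank sums and leaves one entry in `trace()`. A direct call to `toy_reduce_scatter_sum` would compute the same result but leave the trace empty, which is the gap the patch aims to close for the real op.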

facebook-github-bot (Contributor) commented Jun 16, 2022

❌ 1 New Failures

As of commit e73037d (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge) (1/1)

Step: "Test"

2022-06-22T23:42:17.3989535Z ##[error]Process completed with exit code 1.
2022-06-22T23:42:17.1653945Z INFO: Elapsed time: 164.154s
2022-06-22T23:42:17.1654249Z Loading: 0 packages loaded
2022-06-22T23:42:17.1659709Z 
2022-06-22T23:42:17.1660277Z INFO: 0 processes.
2022-06-22T23:42:17.1660633Z Loading: 0 packages loaded
2022-06-22T23:42:17.1660800Z 
2022-06-22T23:42:17.1661092Z FAILED: Build did NOT complete successfully (0 packages loaded)
2022-06-22T23:42:17.1691590Z 
2022-06-22T23:42:17.1696486Z FAILED: Build did NOT complete successfully (0 packages loaded)
2022-06-22T23:42:17.1820820Z Failed to build external libraries: ['/var/lib/jenkins/workspace/xla/build_torch_xla_libs.sh', '-O', '-D_GLIBCXX_USE_CXX11_ABI=1', 'install']
2022-06-22T23:42:17.3989535Z ##[error]Process completed with exit code 1.
2022-06-22T23:42:17.4047307Z Prepare all required actions
2022-06-22T23:42:17.4047601Z Getting action download info
2022-06-22T23:42:17.5466698Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-06-22T23:42:17.5466923Z with:
2022-06-22T23:42:17.5467334Z   github-token: ***
2022-06-22T23:42:17.5467488Z env:
2022-06-22T23:42:17.5467659Z   GIT_DEFAULT_BRANCH: master
2022-06-22T23:42:17.5467845Z ##[endgroup]
2022-06-22T23:42:17.5493918Z ##[group]Run nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a
2022-06-22T23:42:17.5494156Z with:

This comment was automatically generated by Dr. CI.

facebook-github-bot added the "oncall: distributed" label Jun 16, 2022
alanwaketan requested review from aazzolini, ezyang and wconstab and removed request for H-Huang, awgu, mingzhe09088, mrshenli, rohan-varma and zhaojuanmao June 16, 2022 05:32
alanwaketan pushed a commit that referenced this pull request Jun 16, 2022
ghstack-source-id: b9afad5
Pull Request resolved: #79683
alanwaketan requested a review from mrshenli June 16, 2022 20:56
alanwaketan requested a review from wanchaol June 22, 2022 23:44
mrshenli added the "ciflow/trunk" label Jun 24, 2022
mrshenli (Contributor) left a comment

LGTM! Same suggestions on tests as the PR below this one.

Comment on lines +147 to +153
.findSchemaOrThrow("c10d::reduce_scatter_", "")
.typed<c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
const std::vector<at::Tensor>&,
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
Contributor

Would I be correct that the overhead of these two ops is negligible?

Collaborator Author

It should be negligible. We also cache it for subsequent calls to minimize the overhead.
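
One common C++ way to get this kind of caching is a function-local static that performs the lookup once and is reused on every later call. The sketch below illustrates that pattern with a hypothetical registry; `ToyOp`, `find_op_or_throw`, `lookup_count`, and `toy_add` are made-up names, and the thread does not confirm that the patch caches in exactly this way.

```cpp
// Toy model only: the registry and op names here are hypothetical.
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Stand-in for an op handle returned by a registry lookup.
using ToyOp = std::function<long(long, long)>;

// Counts how many times the lookup actually runs.
inline int& lookup_count() {
  static int count = 0;
  return count;
}

// Hypothetical lookup: finds an op by name or throws.
inline ToyOp find_op_or_throw(const std::string& name) {
  ++lookup_count();
  static const std::map<std::string, ToyOp> registry = {
      {"toy::add", [](long a, long b) { return a + b; }},
  };
  auto it = registry.find(name);
  if (it == registry.end()) {
    throw std::runtime_error("no such op: " + name);
  }
  return it->second;
}

// The wrapper resolves the op once. A function-local static is initialized
// on the first call (thread-safe since C++11) and reused afterwards.
inline long toy_add(long a, long b) {
  static const ToyOp op = find_op_or_throw("toy::add");
  return op(a, b);
}
```

Calling `toy_add` repeatedly runs the registry lookup only on the first call, so the per-call cost after that is the indirect call through the cached handle.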

  .def(
  "reduce_scatter",
- [](::c10d::ProcessGroup& pg,
+ [](const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
Contributor

Let's stay consistent on the naming. This one uses pg; the one above uses self.

Collaborator Author

I believe this pg is user-defined, and the one above really is self. That's why there is this inconsistency.

Collaborator Author

Let me follow up on #80246.

alanwaketan (Collaborator Author)

Thanks Shen for approving this pull request. The CI failure is unrelated.

alanwaketan (Collaborator Author)

@pytorchbot merge -f

pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

github-actions (Contributor)

Hey @alanwaketan.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 27, 2022
Pull Request resolved: #79683
Approved by: https://github.com/mrshenli

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/80b50dfa3ab16de3e90dab8eeed003a98a0da1fe

Test plan from GitHub:
python test/distributed/test_c10d_nccl.py -k test_reduce_scatter_ops

Reviewed By: atalman

Differential Revision: D37455687

Pulled By: alanwaketan

fbshipit-source-id: bb78183e5cf5798b6b558488eed07af7cd4d4eff
facebook-github-bot deleted the gh/alanwaketan/36/head branch June 28, 2022 14:17

Labels

ciflow/trunk, cla signed, Merged, oncall: distributed, release notes: distributed (c10d), topic: new features
