
[c10d] Make allreduce as a custom op #79582

Closed
alanwaketan wants to merge 3 commits into gh/alanwaketan/34/base from gh/alanwaketan/34/head

Conversation

@alanwaketan
Collaborator

@alanwaketan alanwaketan commented Jun 14, 2022

Stack from ghstack (oldest at bottom):

Summary:
This patch makes allreduce a custom op so that it is dispatcher
passable. It is one part of the effort to route comm ops through the
dispatcher so that tracing mechanisms that rely on the dispatcher can
trace them, e.g., LazyTensor and AOTAutograd.

Test Plan:
python test/distributed/test_c10d_nccl.py -k test_allreduce_ops
python test/distributed/test_c10d_gloo.py -k test_allreduce_basics
...and other existing distributed tests.


[ghstack-poisoned]
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 14, 2022


❌ 1 New Failures

As of commit 771024f (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-22T21:44:06.4932978Z INFO: Elapsed time: 164.312s
2022-06-22T21:44:06.4933434Z Loading: 0 packages loaded
2022-06-22T21:44:06.4939582Z INFO: 0 processes.
2022-06-22T21:44:06.4941701Z FAILED: Build did NOT complete successfully (0 packages loaded)
2022-06-22T21:44:06.5099123Z Failed to build external libraries: ['/var/lib/jenkins/workspace/xla/build_torch_xla_libs.sh', '-O', '-D_GLIBCXX_USE_CXX11_ABI=1', 'install']
2022-06-22T21:44:06.7073075Z ##[error]Process completed with exit code 1.

This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added the oncall: distributed label Jun 14, 2022
@alanwaketan alanwaketan requested review from aazzolini, ezyang and wconstab and removed request for H-Huang, awgu, mingzhe09088, mrshenli, rohan-varma and zhaojuanmao June 15, 2022 00:03
alanwaketan pushed a commit that referenced this pull request Jun 15, 2022
ghstack-source-id: 1b1d2c9
Pull Request resolved: #79582
auto allreduce_fut =
    ops::allreduce(
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(
            state_),
Contributor


What's going on here?

Collaborator Author


I'm trying to convert a raw pointer to an intrusive_ptr. Is this the way to do so?

Contributor


Curious: who owns this PG instance? I assume it is owned by the Python PG object? If that's the case, will this mess up the refcount? What happens when this temporary intrusive_ptr exits scope?

Collaborator Author


I think PGs are normally owned by a Python object. AllReduceCommHook somehow holds a ProcessGroup* instead of an intrusive_ptr, so I need to convert the raw pointer to an intrusive_ptr.

I don't believe this will mess up the refcount. However, I actually think it's better to replace class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> with class AllReduceCommHook : public CppCommHookInterface<intrusive_ptr<ProcessGroup>>. What do you think?

Contributor


I don't believe this will mess up the refcnt.

If we create an intrusive_ptr from the raw pointer, does this mean we have two separate entities tracking the refcount for the same raw pointer? One is the Python object, and the other is this intrusive_ptr?

I actually think it's better to replace class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> with class AllReduceCommHook : public CppCommHookInterface<intrusive_ptr<ProcessGroup>>. What do you think?

Yep, that does sound better to me.

Collaborator Author


I think for intrusive_ptr the refcount is stored in the object (ProcessGroup) itself; intrusive_ptr is just a way to increment/decrement that refcount. So it shouldn't matter.

Let me make a follow-up patch changing class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> to class AllReduceCommHook : public CppCommHookInterface<intrusive_ptr<ProcessGroup>>.

Contributor


Got it. Can we also add a comment for this in the code? Thank you!

Collaborator

@wanchaol wanchaol left a comment


I guess I can add tests that have a Python tensor overriding __torch_dispatch__ to directly verify that.

I see; I guess we can add those tests later in a separate PR when we actually need them. There are two things off the top of my head that need some input from @mrshenli. Since these might be related to the actual nodes that appear in the IR, we should get some clarity and make them consistent:

  1. About operator suffix and argument ordering: should we make the ATen operator follow our Python-level API, or should we follow the ATen operator naming convention?
  2. Should we let wait() appear in the IR? This might be related to how the CUDA stream sync works in our current tracer.

root_rank, root_tensor, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<ProcessGroup::Work> allreduce_(
Collaborator


Got it, thanks! One thing that caught my eye about this TorchBind Work object: it does not have methods like wait() bound. I guess this is fine initially, as this PR is more about making it a dispatcher-level op.

But I am wondering what this would look like in our traced IR. Should we have the wait() in the graph? What does the traced graph look like if we need async execution on a different CUDA stream, where the user usually needs to manually wait for the stream? cc @mrshenli

@alanwaketan
Collaborator Author

  1. About operator suffix and argument ordering: should we make the ATen operator follow our Python-level API, or should we follow the ATen operator naming convention?

I think we should follow the ATen convention for the function schema, as that will make it easier for any tracer to interpret the ops. At least AOTAutograd would assume the ATen convention.

  2. Should we let wait() appear in the IR? This might be related to how the CUDA stream sync works in our current tracer.

Please see my other comments for the short-term solution. Long term, yes, we need a way to represent CUDA streams in the graph. We don't know how yet.

Collaborator

@wanchaol wanchaol left a comment


Looks good to me, looks like the CI failure is real:

Broken ops: [
	c10d::broadcast(__torch__.torch.classes.c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3, int _4) -> __torch__.torch.classes.c10d.Work _0
]

Could you fix the CI issue before landing? Thanks!

@alanwaketan
Collaborator Author

Looks good to me, looks like the CI failure is real:

Broken ops: [
	c10d::broadcast(__torch__.torch.classes.c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3, int _4) -> __torch__.torch.classes.c10d.Work _0
]

Could you fix the CI issue before landing? Thanks!

Thanks, Wanchao. I believe it's intended to break the schema. Do you know how to update the test expectation of the backward_compat test?

alanwaketan pushed a commit that referenced this pull request Jun 22, 2022
ghstack-source-id: 220759b
Pull Request resolved: #79582
("aten::segment_reduce", datetime.date(2022, 6, 30)),
("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
("aten::empty.SymInt", datetime.date(9999, 1, 1)),
("c10d::broadcast", datetime.date(2022, 6, 25)),
Collaborator


Yeah, this is the correct way :) although I am not sure which namespace broadcast got bound to. It looks like it's c10d, so we can see if this gets it passed.

@alanwaketan
Collaborator Author

The XLA failure doesn't seem to be related.

@pytorch pytorch deleted a comment from pytorch-bot bot Jun 23, 2022
@pytorch pytorch deleted a comment from pytorchmergebot Jun 23, 2022
@pytorch pytorch deleted a comment from pytorchmergebot Jun 23, 2022
@alanwaketan
Collaborator Author

@pytorchbot merge -f

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

Hey @alanwaketan.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 23, 2022
Pull Request resolved: #79582

Approved by: https://github.com/wanchaol

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e5841bafbd2868eaf6eb7b89b4caf3a6261dcfa6

Reviewed By: atalman

Differential Revision: D37382098

Pulled By: alanwaketan

fbshipit-source-id: 068fd6d8f2c3fa3998431dcf878e14bd41890693
@mrshenli
Contributor

should we let wait() appear in the IR? This might be related to how the cuda stream sync works in our current tracer.

Can this be represented as edges in the graph?

@alanwaketan
Collaborator Author

should we let wait() appear in the IR? This might be related to how the cuda stream sync works in our current tracer.

Can this be represented as edges in the graph?

I think we need more discussions on this. Let me try to organize a follow up meeting.

@facebook-github-bot facebook-github-bot deleted the gh/alanwaketan/34/head branch June 26, 2022 14:16
crcrpar pushed a commit to crcrpar/pytorch that referenced this pull request Aug 22, 2022
Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: ptrblck <[email protected]>
Co-authored-by: Michael Carilli <[email protected]>

Patch for pytorch#79582

Apparently 79852 is newer than 34 and the commit below so the PR assumes
`ReduceOp` to be an `enum`, not a `struct` including an `enum` inside
it.

Labels

cla signed · Merged · oncall: distributed · release notes: distributed (c10d) · topic: new features

6 participants