
Add reduce_scatter_tensor in place of _reduce_scatter_base #85867

Closed
kwen2501 wants to merge 2 commits into master from reduce_scatter_tensor

Conversation

@kwen2501
Collaborator

This is a twin PR similar to the one for all_gather_into_tensor (#85686).
The philosophy for renaming _reduce_scatter_base instead of merging it is described in #85686.

Cc @rohan-varma @H-Huang @crcrpar @ptrblck @mrshenli
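For readers unfamiliar with the collective being renamed: reduce-scatter sum-reduces each rank's full-length input element-wise across all ranks, then scatters equal shards of the result so rank r keeps the r-th shard. The helper below is a hypothetical single-process sketch of those semantics (no process group involved), not the actual `reduce_scatter_tensor` implementation:

```python
# Hypothetical sketch of reduce-scatter semantics: each rank contributes a
# full-length input list; rank r receives the r-th equal shard of the
# element-wise sum of all inputs.
def reduce_scatter_sim(inputs):
    world_size = len(inputs)
    n = len(inputs[0])
    assert n % world_size == 0, "input length must divide evenly by world size"
    shard = n // world_size
    summed = [sum(vals) for vals in zip(*inputs)]  # element-wise reduction
    return [summed[r * shard:(r + 1) * shard] for r in range(world_size)]

# Two ranks each contributing [0, 1, 2, 3]:
outs = reduce_scatter_sim([[0, 1, 2, 3], [0, 1, 2, 3]])
```

With two ranks the sum is `[0, 2, 4, 6]`, so rank 0 receives `[0, 2]` and rank 1 receives `[4, 6]`, matching the docstring example discussed later in this review.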

@pytorch-bot

pytorch-bot bot commented Sep 28, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85867

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 945f466:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Sep 28, 2022
@facebook-github-bot facebook-github-bot added cla signed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Sep 28, 2022
Collaborator

@crcrpar crcrpar left a comment


looks good to me, appreciate your swift action

Contributor

@H-Huang H-Huang left a comment


LGTM, thanks!

)
return tensor_out

@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor")
Contributor


nit: I think this decorator @requires_nccl() does a similar check

Suggested change
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor")
@requires_nccl()
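For context on the suggestion: both decorators gate a test on a runtime condition. A conditional-skip decorator of this general shape can be sketched in plain `unittest` terms; the `skip_if` name and the `BACKEND` value below are illustrative stand-ins, not the real `sandcastle_skip_if` implementation:

```python
import functools
import unittest

def skip_if(condition, reason):
    """Hypothetical stand-in for sandcastle_skip_if: skip the test when
    `condition` is true, otherwise run it unchanged."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if condition:
                raise unittest.SkipTest(reason)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

BACKEND = "gloo"  # assumed value for illustration

@skip_if(BACKEND != "nccl", "Only NCCL supports CUDA reduce_scatter_tensor")
def test_reduce_scatter_tensor():
    return "ran"
```

Because `BACKEND` is not `"nccl"` here, calling the decorated test raises `unittest.SkipTest` instead of running the body, which is the behavior both decorators in the suggestion provide.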

Collaborator Author


Will fix later.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 29, 2022
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

.. warning::
Contributor


Collaborator Author


Will fix later.

@kwen2501
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered with the land checks (-l) flag. If you did not specify this flag yourself, you are likely enrolled in the land checks rollout. This means that your change will be merged once all checks on your PR have passed since you have added the ciflow/trunk label to your PR (ETA 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 additional jobs have failed, first few of them are: trunk, trunk / linux-focal-rocm5.2-py3.7 / test (default, 1, 2, linux.rocm.gpu)

If you believe this is an error, you can use the old behavior with @pytorchbot merge -g (optionally with the ciflow/trunk to get land checks) or use @pytorchbot merge -f "some reason here". For more information, see the bot wiki.

Please reach out to the PyTorch DevX Team with feedback or questions!

Details for Dev Infra team Raised by workflow job

@kwen2501
Collaborator Author

@pytorchbot rebase -m

@pytorch-bot

pytorch-bot bot commented Sep 29, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: -m

usage: @pytorchbot [-h] {merge,revert,rebase,label} ...

Try @pytorchbot --help for more info.

@kwen2501
Collaborator Author

@pytorchbot rebase -b master

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased reduce_scatter_tensor onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout reduce_scatter_tensor && git pull --rebase)

@kwen2501
Collaborator Author

Here is the failed check:

TestCommonCUDA.test_python_ref_executor__refs_atanh_executor_aten_cuda_complex32

It does not seem related to my change.

And the failure reason is:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 7.44 GiB total capacity; 312.49 MiB already allocated; 10.19 MiB free; 1.64 GiB allowed; 860.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Also does not seem to be related to my change.

@huydhn FYI.

@pytorchbot merge -f "The one failure does not seem related to my change"
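As an aside, the `max_split_size_mb` hint in the OOM message above is applied through the `PYTORCH_CUDA_ALLOC_CONF` environment variable set before launching the job; the value `128` below is purely illustrative, not a recommendation from this thread:

```shell
# Illustrative value only: cap the allocator's block-split size to reduce
# fragmentation, per the OOM message's suggestion. Set before launching.
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
echo "$PYTORCH_CUDA_ALLOC_CONF"
```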

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered with the force (-f) flag. This means your change will be merged immediately, bypassing any CI checks (ETA: 1-5 minutes). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions
Contributor

Hey @kwen2501.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@kwen2501 kwen2501 added the topic: new features topic category label Sep 30, 2022
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022
This is a twin PR similar to the one for `all_gather_into_tensor` (#85686).
The philosophy for renaming `_reduce_scatter_base` instead of merging it is described in #85686.

Cc @rohan-varma @H-Huang @crcrpar @ptrblck @mrshenli

Pull Request resolved: #85867
Approved by: https://github.com/crcrpar, https://github.com/H-Huang
facebook-github-bot pushed a commit to facebookresearch/param that referenced this pull request Oct 27, 2022
Summary:
remove deprecated _all_gather_base and _reduce_scatter_base APIs and use the new replacements
- `_all_gather_base` -> `all_gather_into_tensor`
- `_reduce_scatter_base` -> `reduce_scatter_tensor`
- correct begin size for reduce_scatter_base

see pytorch/pytorch#85686 and pytorch/pytorch#85867 for the changes in PyTorch

Reviewed By: wesbland

Differential Revision: D40703655

fbshipit-source-id: c871c05a8de687a34a124b60857879c983b8cc3d
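The rename in the param commit above is mechanical. For code that must run on both old and new PyTorch builds, a hypothetical compatibility shim could resolve the new names first and fall back to the deprecated underscore-prefixed ones; everything below (the `resolve_collectives` helper and the stand-in module) is illustrative, not code from either repository:

```python
import types

# Hypothetical shim: prefer the new collective names, falling back to the
# deprecated underscore-prefixed ones on older PyTorch builds.
def resolve_collectives(dist):
    all_gather = getattr(dist, "all_gather_into_tensor",
                         getattr(dist, "_all_gather_base", None))
    reduce_scatter = getattr(dist, "reduce_scatter_tensor",
                             getattr(dist, "_reduce_scatter_base", None))
    return all_gather, reduce_scatter

# Usage with a stand-in module that only has the old names:
fake_dist = types.SimpleNamespace(
    _all_gather_base=lambda *a: "old-ag",
    _reduce_scatter_base=lambda *a: "old-rs",
)
ag, rs = resolve_collectives(fake_dist)
```

On a build that exposes the new names, the same helper returns those instead, so call sites stay unchanged across the deprecation window.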
@github-actions github-actions bot deleted the reduce_scatter_tensor branch March 31, 2024 01:51

Labels

ciflow/trunk · cla signed · Merged · oncall: distributed · release notes: distributed (c10d) · topic: new features


5 participants