
Add all_gather_into_tensor in place of _all_gather_base#85686

Closed
kwen2501 wants to merge 2 commits into master from all_gather_into_tensor

Conversation

Collaborator

@kwen2501 kwen2501 commented Sep 27, 2022

Description

  • This PR renames _all_gather_base to all_gather_into_tensor so that it is clearer in meaning.
  • The all_gather_into_tensor API differs from the all_gather API in the output it accepts -- a single, large tensor instead of a list of tensors.
  • This PR also adds a deprecation warning to _all_gather_base.
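The difference in output convention can be illustrated with a small pure-Python emulation (no process group required; these helpers are illustrative stand-ins, not the real c10d calls):

```python
# Pure-Python emulation of the two output conventions. `world` stands in
# for the per-rank input tensors; the helper names are hypothetical.
world = [[1, 2], [3, 4]]  # rank 0 holds [1, 2]; rank 1 holds [3, 4]

def emulate_all_gather(world):
    # all_gather: the output is a LIST with one entry per rank.
    return [list(t) for t in world]

def emulate_all_gather_into_tensor(world):
    # all_gather_into_tensor: the output is ONE flat tensor (cat form).
    return [x for t in world for x in t]

print(emulate_all_gather(world))              # [[1, 2], [3, 4]]
print(emulate_all_gather_into_tensor(world))  # [1, 2, 3, 4]
```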

Issue

_all_gather_base was implemented in #33924 to avoid unnecessary flattening. There was previous effort (#82639) to merge _all_gather_base with the existing all_gather API by detecting the parameter type passed in for the output.

There are, however, two "blockers" that make the merge difficult:
(i) The merge breaks backward compatibility: we would need to change the parameter name tensor_list in all_gather to a more general name, output, that can cover both a tensor and a tensor list.
(ii) The all_gather API recently gained uneven-tensor support, which relies on the tensor boundaries implied by the list. We are not sure whether to add such support to _all_gather_base, because that would require users to pass in additional tensor-boundary information.

In view of the above, we decided to productize _all_gather_base as a separate function with a clearer name.

Testing

Added tests:

  • test_all_gather_into_cat_tensor_cuda -- output form as with torch.cat. For example:
        >>> tensor_in
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> tensor_out
        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
  • test_all_gather_into_stack_tensor_cuda -- output form as with torch.stack. For example:
        >>> tensor_out2
        tensor([[1, 2],
                [3, 4]], device='cuda:0') # Rank 0
        tensor([[1, 2],
                [3, 4]], device='cuda:1') # Rank 1

The output form is determined by the shape of the output tensor the user passes in; no flag is needed.
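The shape rule behind that dispatch can be sketched as a tiny checker (`output_form` is a hypothetical helper, not part of this PR; the real API simply writes into whatever compatible tensor the caller provides):

```python
def output_form(in_shape, out_shape, world_size):
    """Classify an output tensor shape as 'cat' or 'stack' form.

    Hypothetical helper encoding the shape rule described above.
    """
    if out_shape == (world_size,) + in_shape:
        return "stack"  # torch.stack-like: new leading dim of size world_size
    if out_shape == (world_size * in_shape[0],) + in_shape[1:]:
        return "cat"    # torch.cat-like: first dim multiplied by world_size
    raise ValueError("output shape incompatible with input shape and world size")

print(output_form((2,), (4,), world_size=2))    # cat
print(output_form((2,), (2, 2), world_size=2))  # stack
```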

Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang

@pytorch-bot

pytorch-bot bot commented Sep 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85686

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 1 Pending

As of commit 5829bab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Sep 27, 2022
@facebook-github-bot facebook-github-bot added cla signed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Sep 27, 2022
Contributor

@rohan-varma rohan-varma left a comment


LGTM, thank you!

Collaborator

@crcrpar crcrpar left a comment


thank you for the update



```diff
-def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
+def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
```
Collaborator


let's update the PyTorch web documentation by adding this to https://github.com/pytorch/pytorch/blob/master/docs/source/distributed.rst#collective-functions

Collaborator Author


Great suggestion. Added now.

```python
def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
    """
    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
    All ranks gather tensors from all other ranks and put them into a single
```
Collaborator Author

Updated now to "Gather tensors from all ranks and put them into ..."

Collaborator

@crcrpar crcrpar left a comment

@kwen2501
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here and land check progress here.
The merge job was triggered with the land checks (-l) flag. If you did not specify this flag yourself, you are likely enrolled in the land checks rollout. This means that your change will be merged once all checks on your PR and the land checks have passed (ETA 4 Hours). If you need to coordinate lands between different changes and cannot risk a land race, please add the ciflow/trunk label to your PR and wait for signal to complete, and then land your changes in proper order. Having trunk, pull, and Lint pre-run on a PR will bypass land checks and the ETA should be immediate. If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions
Contributor

Hey @kwen2501.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@kwen2501 kwen2501 added the topic: new features topic category label Sep 28, 2022
@ZainRizvi ZainRizvi added ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR and removed ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Sep 28, 2022
drisspg pushed a commit to drisspg/pytorch that referenced this pull request Sep 29, 2022
Pull Request resolved: pytorch#85686
Approved by: https://github.com/rohan-varma, https://github.com/crcrpar
pytorchmergebot pushed a commit that referenced this pull request Sep 30, 2022
This is a twin PR similar to the one for `all_gather_into_tensor` (#85686).
The philosophy for renaming `_reduce_scatter_base` instead of merging it is described in #85686.

Cc @rohan-varma @H-Huang @crcrpar @ptrblck @mrshenli

Pull Request resolved: #85867
Approved by: https://github.com/crcrpar, https://github.com/H-Huang
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022
facebook-github-bot pushed a commit to facebookresearch/param that referenced this pull request Oct 27, 2022
Summary:
Remove the deprecated `_all_gather_base` and `_reduce_scatter_base` APIs and use the new replacements:
- `_all_gather_base` -> `all_gather_into_tensor`
- `_reduce_scatter_base` -> `reduce_scatter_tensor`
- correct the begin size for `reduce_scatter_base`

see pytorch/pytorch#85686 and pytorch/pytorch#85867 for the changes in PyTorch

Reviewed By: wesbland

Differential Revision: D40703655

fbshipit-source-id: c871c05a8de687a34a124b60857879c983b8cc3d
@github-actions github-actions bot deleted the all_gather_into_tensor branch March 29, 2024 01:50

Labels

cla signed Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category topic: new features topic category
