
all_gather supporting single tensor output#82639

Closed
kwen2501 wants to merge 5 commits intopytorch:masterfrom
kwen2501:all_gather_single_output_tensor

Conversation


@kwen2501 kwen2501 commented Aug 2, 2022

Description

This PR unifies the `all_gather` API and the `_all_gather_base` API by having the former support both a single tensor and a list of tensors as output.

Issue

Following the request in #33924 to avoid unnecessary flattening, `_all_gather_base` was implemented. That API, however, was never productized, due to the concern of creating multiple all-gather variants. Thus, instead of productizing `_all_gather_base`, this PR uses `all_gather` to cover both cases, because in the traditional sense (as in MPI or NCCL), all-gather can indeed mean gathering multiple small buffers into a single big buffer.
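To illustrate the flattening concern, here is a schematic only (hypothetical function names; plain Python lists stand in for tensors and for the communication step): with a list output, a backend typically packs data into one contiguous buffer and then copies chunks back out into the user's list, whereas a single preallocated output tensor can be written directly, skipping the repacking copies.

```python
def all_gather_via_list(output_list, my_tensor, all_inputs):
    # List API: gather into a temporary contiguous buffer, then scatter
    # the chunks back into the user's list -- extra copies per call.
    flat = []
    for t in all_inputs:                      # "communication" into flat buffer
        flat.extend(t)
    n = len(my_tensor)
    for i, out in enumerate(output_list):
        out[:] = flat[i * n:(i + 1) * n]      # copy chunks back out
    return output_list

def all_gather_into_tensor(output, my_tensor, all_inputs):
    # Single-tensor API: write straight into the user's buffer, no repack.
    i = 0
    for t in all_inputs:
        for x in t:
            output[i] = x
            i += 1
    return output

ranks = [[1, 2], [3, 4]]                      # contributions of rank 0 and rank 1
listed = all_gather_via_list([[0, 0], [0, 0]], ranks[0], ranks)
merged = all_gather_into_tensor([0, 0, 0, 0], ranks[0], ranks)
print(listed)                                 # [[1, 2], [3, 4]]
print(merged)                                 # [1, 2, 3, 4]
```

Both calls deliver the same data; the difference is purely the layout the caller receives and the number of intermediate copies a real backend would need.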

Testing

Added tests:
- `test_all_gather_single_output_tensor_cuda`

Context:

In a recent discussion with @ptrblck and @crcrpar regarding use of the base feature in Megatron, we identified that unifying the two APIs was preferable to productizing `_all_gather_base`. This also follows the initial idea to unify them, mentioned by @zhaojuanmao when `_all_gather_base` was first implemented (#56315 (review)).

There are two ways to unify the two APIs:
Method 1: have all_gather accept a single tensor.
Method 2: have all_gather accept a list containing a single tensor.
Method 1 appears cleaner in semantics -- when a user passes in a single tensor, it is explicit that they are asking for the input tensors to be merged into one (larger) output tensor. Also, since the output is something the user passes in rather than a return value, the user never has to figure out what the return type is (a list or a tensor).

Regarding why a single function name:
I want to give users an experience where they only need to think about what output format they want, without additionally having to think about which function to call (especially when there are already several all_gather_* APIs). In retrospect, it may also seem that the list implementation "hijacked" the all_gather name, whereas the single-tensor implementation can be more natural, smoother in application flow, and more performant. Hence my thought of giving the tensor output the "first-class" name all_gather.
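As a minimal, framework-free sketch of the Method 1 dispatch (the `Tensor` class, `_all_gather_base_impl`, and `_all_gather_list_impl` below are hypothetical stand-ins for the real `torch` types and backend calls; a plain list simulates the ranks' inputs):

```python
from typing import List, Union

class Tensor:
    """Hypothetical stand-in for torch.Tensor."""
    def __init__(self, data):
        self.data = list(data)

def _all_gather_base_impl(output: Tensor, tensor: Tensor, world) -> None:
    # Gather every rank's input into one large, contiguous output tensor.
    # `tensor` is this rank's contribution; here it is already in `world`.
    output.data = [x for t in world for x in t.data]

def _all_gather_list_impl(output: List[Tensor], tensor: Tensor, world) -> None:
    # Gather every rank's input into a list of per-rank tensors.
    for out, t in zip(output, world):
        out.data = list(t.data)

def all_gather(output: Union[Tensor, List[Tensor]], tensor: Tensor, world) -> None:
    # Method 1: the type of the output argument selects the variant.
    if isinstance(output, Tensor):
        return _all_gather_base_impl(output, tensor, world)
    return _all_gather_list_impl(output, tensor, world)

# Two ranks, each contributing a 2-element tensor.
world = [Tensor([1, 2]), Tensor([3, 4])]

big = Tensor([0, 0, 0, 0])
all_gather(big, world[0], world)              # single-tensor output
print(big.data)                               # [1, 2, 3, 4]

outs = [Tensor([0, 0]), Tensor([0, 0])]
all_gather(outs, world[0], world)             # list output
print([o.data for o in outs])                 # [[1, 2], [3, 4]]
```

The point of the sketch: the caller chooses the output layout by what they pass in, so no extra flag and no second function name are needed.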


facebook-github-bot commented Aug 2, 2022


❌ 3 New Failures, 2 Pending

As of commit 9dad87f (more details on the Dr. CI page):

  • 3/3 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (functorch, 1, 1, windows.4xlarge) (1/3)

Step: "Test" (full log | diagnosis details)

2022-08-04T12:39:50.4784411Z 
2022-08-04T12:39:50.4784543Z FAILED (errors=10, skipped=1620, expected failures=304)
2022-08-04T12:39:50.4784699Z 
2022-08-04T12:39:50.4784791Z Generating XML reports...
2022-08-04T12:39:50.4785160Z Generated XML report: test-reports\python-unittest\functorch\test\test_ops\TEST-TestOperatorsCPU-20220804122724.xml
2022-08-04T12:39:50.9719683Z Traceback (most recent call last):
2022-08-04T12:39:50.9720058Z   File "run_test.py", line 973, in <module>
2022-08-04T12:39:50.9720263Z     main()
2022-08-04T12:39:50.9720457Z   File "run_test.py", line 951, in main
2022-08-04T12:39:50.9720695Z     raise RuntimeError(err_message)
2022-08-04T12:39:50.9721006Z RuntimeError: C:\actions-runner\_work\pytorch\pytorch\functorch\test\test_ops failed!
2022-08-04T12:39:51.1801533Z 
2022-08-04T12:39:51.1802242Z (base) C:\actions-runner\_work\pytorch\pytorch\test>popd
2022-08-04T12:39:51.1805903Z 
2022-08-04T12:39:51.1806156Z (base) C:\actions-runner\_work\pytorch\pytorch>if ERRORLEVEL 1 goto fail 
2022-08-04T12:39:51.1808025Z 
2022-08-04T12:39:51.1808208Z (base) C:\actions-runner\_work\pytorch\pytorch>exit /b 1 
2022-08-04T12:39:51.1862793Z ##[error]Process completed with exit code 1.
2022-08-04T12:39:51.1997071Z Prepare all required actions
2022-08-04T12:39:51.1997587Z Getting action download info
2022-08-04T12:39:51.3570334Z Download action repository 'nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a' (SHA:71062288b76e2b6214ebde0e673ce0de1755740a)

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (2/3)

Step: "Test" (full log | diagnosis details)

2022-08-04T12:26:22.1941256Z [gw1] [  1%] FAILED test_ops.py::TestCommonCPU::test_dtypes__refs_nn_functional_prelu_cpu 
2022-08-04T12:26:22.1941442Z 
2022-08-04T12:26:22.1941556Z ================================== FAILURES ===================================
2022-08-04T12:26:22.1941849Z ___________ TestCommonCPU.test_dtypes__refs_nn_functional_prelu_cpu ___________
2022-08-04T12:26:22.1942198Z [gw1] win32 -- Python 3.8.13 C:\Jenkins\Miniconda3\python.exe
2022-08-04T12:26:22.1942438Z Traceback (most recent call last):
2022-08-04T12:26:22.1950359Z   File "C:\actions-runner\_work\pytorch\pytorch\test\test_ops.py", line 1227, in test_dtypes
2022-08-04T12:26:22.1950735Z     self.fail(msg)
2022-08-04T12:26:22.1951037Z   File "C:\Jenkins\Miniconda3\lib\unittest\case.py", line 753, in fail
2022-08-04T12:26:22.1951340Z     raise self.failureException(msg)
2022-08-04T12:26:22.1951655Z AssertionError: The supported dtypes for _refs.nn.functional.prelu on device type cpu are incorrect!
2022-08-04T12:26:22.1952086Z The following dtypes did not work in forward but are listed by the OpInfo: {torch.float64, torch.bfloat16, torch.float32}.
2022-08-04T12:26:22.1952312Z 
2022-08-04T12:26:22.1952552Z - generated xml file: C:\actions-runner\_work\pytorch\pytorch\test\test-reports\python-pytest\test_ops\test_ops.xml -
2022-08-04T12:26:22.1952914Z =========================== short test summary info ===========================
2022-08-04T12:26:22.1953213Z FAILED test_ops.py::TestCommonCPU::test_dtypes__refs_nn_functional_prelu_cpu
2022-08-04T12:26:22.1953520Z !!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!
2022-08-04T12:26:22.1953822Z !!!!!!!!!!!! xdist.dsession.Interrupted: stopping after 1 failures !!!!!!!!!!!!
2022-08-04T12:26:22.1954117Z ======= 1 failed, 240 passed, 7 skipped, 42 warnings, 2 rerun in 21.65s =======
2022-08-04T12:26:22.7351361Z Traceback (most recent call last):
2022-08-04T12:26:22.7351725Z   File "run_test.py", line 973, in <module>

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (3/3)

Step: "Test" (full log | diagnosis details)

2022-08-04T12:34:28.7077291Z =========================== short test summary info ===========================
2022-08-04T12:34:28.7077624Z FAILED test_ops_gradients.py::TestGradientsCPU::test_fn_fwgrad_bwgrad_nn_functional_prelu_cpu_float64
2022-08-04T12:34:28.7077943Z !!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!
2022-08-04T12:34:28.7078220Z !!!!!!!!!!!! xdist.dsession.Interrupted: stopping after 1 failures !!!!!!!!!!!!
2022-08-04T12:34:28.7078529Z = 1 failed, 321 passed, 1000 skipped, 7 xfailed, 49 warnings, 2 rerun in 50.24s =
2022-08-04T12:34:29.1323282Z Traceback (most recent call last):
2022-08-04T12:34:29.1323650Z   File "run_test.py", line 973, in <module>
2022-08-04T12:34:29.1323862Z     main()
2022-08-04T12:34:29.1324068Z   File "run_test.py", line 951, in main
2022-08-04T12:34:29.1324304Z     raise RuntimeError(err_message)
2022-08-04T12:34:29.1324535Z RuntimeError: test_ops_gradients failed!
2022-08-04T12:34:29.3280113Z 
2022-08-04T12:34:29.3280670Z (base) C:\actions-runner\_work\pytorch\pytorch\test>if ERRORLEVEL 1 goto fail 
2022-08-04T12:34:29.3282993Z 
2022-08-04T12:34:29.3283173Z (base) C:\actions-runner\_work\pytorch\pytorch\test>exit /b 1 
2022-08-04T12:34:29.3334138Z ##[error]Process completed with exit code 1.
2022-08-04T12:34:29.3460180Z Prepare all required actions
2022-08-04T12:34:29.3460700Z Getting action download info
2022-08-04T12:34:29.5132306Z Download action repository 'nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a' (SHA:71062288b76e2b6214ebde0e673ce0de1755740a)
2022-08-04T12:34:29.7582123Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-08-04T12:34:29.7582368Z with:

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Aug 2, 2022
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

"""
# `tensor_list` would be understood as "tensor or list."
Contributor


Maybe it's best then to name it more verbosely, `tensor_or_list`, to avoid it being understood as "list of tensors"?

Collaborator Author


Thanks for the suggestion. I changed the name to `output` to be generic (and also to indicate it is an output).
In the documentation, I added that `output` can be a tensor or a list of tensors.

[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

"""
# `tensor_list` would be understood as "tensor or list."
Contributor


Should the doc be updated? Also, I'm not sure how much of a fan I am of an argument possibly having multiple types, with the behavior changing based on the type - it seems that it can be surprising / confusing to users. But if we think this is the best way to consolidate the API, then I am open to it.

Collaborator Author


Thanks for the suggestion. I updated the document. Can you please take a look? Thank you!

Also filling in a bit of context:

  • In a recent discussion with @ptrblck and @crcrpar regarding use of the base feature in Megatron, we identified that unifying the two APIs was preferable to productizing `_all_gather_base`. This also follows the initial idea to unify them, mentioned by @zhaojuanmao when `_all_gather_base` was first implemented ([torch distributed] Implementing all_gather_base (#56304) #56315 (review)).

  • I also discussed with @ptrblck and @crcrpar ways to unify the two APIs.
    Method 1: have all_gather accept a single tensor.
    Method 2: have all_gather accept a list containing a single tensor.
    Method 1 appears cleaner in semantics -- when a user passes in a single tensor, it is explicit that they are asking for the input tensors to be merged into one (larger) output tensor. Also, since the output is something the user passes in rather than a return value, the user never has to figure out what the return type is (a list or a tensor).

  • Regarding why a single function name:
    I want to give users an experience where they only need to think about what output format they want, without additionally having to think about which function to call (especially when there are already several all_gather_* APIs). In retrospect, it may also seem that the list implementation "hijacked" the all_gather name, whereas the single-tensor implementation can be more natural, smoother in application flow, and more performant. Hence my thought of giving the tensor output the "first-class" name all_gather.

I am happy to hear about comments on the above thoughts.

# `tensor_list` would be understood as "tensor or list."
# If it is a single tensor, we would route it to the _all_gather_base implementation.
if isinstance(tensor_list, torch.Tensor):
    return _all_gather_base(tensor_list, tensor, group, async_op)
Contributor


How about a new flag for dist.all_gather such as inplace=True? When it is provided, we document that the first arg should be an appropriately sized tensor, and we fill the result into it.

Collaborator Author


This PR is tangential to in-place vs out-of-place. It is more about in which format the user wants to receive the all-gather result.

Collaborator

@crcrpar crcrpar left a comment


Thank you for the PR. I dropped some newbie questions :)



- def all_gather(tensor_list, tensor, group=None, async_op=False):
+ def all_gather(output, tensor, group=None, async_op=False):
Collaborator


IMHO this looks like it breaks backward compatibility, and I'm wondering if we could just make `_all_gather_base` public.

Collaborator Author


Hi @crcrpar, I started writing my reply to @rohan-varma on why the unification and why such signature before seeing your comment. Do you mind checking if that responds to the second part of your question? Please feel free to let me know if you have any comment about that thought process.

Regarding backward compatibility, I will think more about it.

Comment on lines +2120 to +2123
# `output` can be a single tensor or a list of tensors.
# If it is a single tensor, we route it to the _all_gather_base implementation.
if isinstance(output, torch.Tensor):
    return _all_gather_base(output, tensor, group, async_op)
Collaborator


To my knowledge, `_all_gather_base` is not a custom op in terms of TorchScript, while `all_gather` is, according to #79669.

Could this mean that one function behaves slightly differently depending on the output?

Collaborator Author


We can make _all_gather_base a custom op in the backend too. That only concerns the internal implementation.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Sep 7, 2022
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2022
### Description
- This PR renames `_all_gather_base` to `all_gather_into_tensor` so that it is clearer in meaning.
- The `all_gather_into_tensor` API differs from the `all_gather` API in the output it accepts -- a single, large tensor instead of a list of tensors.
- This PR also adds a deprecation warning to `_all_gather_base`.

### Issue
`_all_gather_base` was implemented in #33924 to avoid unnecessary flattening. There was previous effort (#82639) to merge `_all_gather_base` with the existing `all_gather` API by detecting the parameter type passed in for the output.

There are, however, two "blockers" that make the merge difficult:
(i) The merge breaks backward compatibility: we would need to change the parameter name `tensor_list` in `all_gather` to a general name, `output`, that can cover both a tensor and a tensor list.
(ii) Recently, the `all_gather` API added uneven tensor support, utilizing the tensor boundaries implied by the list. We are, however, not sure about adding such support to `_all_gather_base`, because that would require users to pass in additional tensor boundary information.

In view of the above, we decided to productize `_all_gather_base` as a separate function, but with a clearer name.

### Testing
Added tests:
- `test_all_gather_into_cat_tensor_cuda` -- output form as with `torch.cat`. For example:
```
        >>> tensor_in
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> tensor_out
        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
```
- `test_all_gather_into_stack_tensor_cuda` -- output form as with `torch.stack`. For example:
```
        >>> tensor_out2
        tensor([[1, 2],
                [3, 4]], device='cuda:0') # Rank 0
        tensor([[1, 2],
                [3, 4]], device='cuda:1') # Rank 1
```
The output form is determined by the shape of the output tensor passed by the user, no flag used.

Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang
Pull Request resolved: #85686
Approved by: https://github.com/rohan-varma, https://github.com/crcrpar
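The shape-based dispatch described in the commit message can be sketched without torch: a 1-D output of length `world_size * n` corresponds to the `torch.cat` form, and a 2-D output of shape `(world_size, n)` to the `torch.stack` form. A pure-Python illustration (the function name is hypothetical, and nested lists stand in for tensors):

```python
def gather_layout_from_shape(out_shape, inputs):
    """Pick the cat-like vs stack-like layout from the output shape.

    inputs: one flat list of length n per rank.
    Returns nested lists mimicking the gathered tensor.
    """
    world_size, n = len(inputs), len(inputs[0])
    if out_shape == (world_size * n,):        # cat-like: one long vector
        return [x for rank in inputs for x in rank]
    if out_shape == (world_size, n):          # stack-like: one row per rank
        return [list(rank) for rank in inputs]
    raise ValueError("output shape incompatible with world_size and input shape")

inputs = [[1, 2], [3, 4]]                     # rank 0 and rank 1 contributions
print(gather_layout_from_shape((4,), inputs))     # [1, 2, 3, 4]
print(gather_layout_from_shape((2, 2), inputs))   # [[1, 2], [3, 4]]
```

This mirrors the behavior the tests above exercise: the same data is delivered either concatenated or stacked, selected purely by the output tensor's shape, with no extra flag.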
drisspg pushed a commit to drisspg/pytorch that referenced this pull request Sep 29, 2022
@kwen2501
Copy link
Copy Markdown
Collaborator Author

kwen2501 commented Oct 1, 2022

Closing this PR since an alternative solution (#85686) has been landed.

@kwen2501 kwen2501 closed this Oct 1, 2022
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022

Labels

cla signed oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category
