
Avoid unnecessary flattening in all-gather and reduce-scatter #33924

Closed
thorjohnsen wants to merge 13 commits into pytorch:master from thorjohnsen:avoid_unnecessary_flattening_for_allgather

Conversation

@thorjohnsen
Contributor

This PR addresses an important special case, which comes up frequently as we work towards model-parallel training and a few other things. torch.distributed.all_gather takes two arguments: a list of output tensors and a single input tensor. The all_gather op collects inputs from all the ranks into the list of output tensors. Internally, the rank inputs are collected into a flattened tensor and then copied from there to the output tensors. Often, the list of output tensors is simply a list of views into an already flattened tensor, in which case the un-flattening is unnecessary and we can do an in-place all-gather op instead. This saves memory and improves performance. This is particularly important when the all-gather is done within a single node, because all the transfers are done over high-speed NVLink and the extra D2D copies really hurt performance. To demonstrate the difference this PR makes, I ran the test script shown below on a DGX-1 with 8 x 32 GB V100 cards, using NVIDIA's 19.11 devel image. Without this PR, it took 13.96 seconds (average of 5 runs) to complete 100 all-gathers of a flattened tensor with 1.07 billion floats (1024^3). With this PR, the same 100 all-gathers took 3.01 seconds, a 4.64x improvement.
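The layout check described above (the output tensors are back-to-back views into one already-flattened tensor) can be sketched as follows. This is an illustrative NumPy version of the aliasing-and-offset test the PR performs in C++ via `is_alias_of` and `storage_offset`, using raw data pointers in place of storage offsets; it is not the actual implementation:

```python
import numpy as np

def outputs_are_flat_views(chunks):
    # Every chunk must start exactly where the previous one ended, so the
    # whole list is just back-to-back views of one flat buffer; only then
    # can the collective use the buffer directly with no extra D2D copies.
    expected_ptr = chunks[0].__array_interface__['data'][0]
    for c in chunks:
        if c.__array_interface__['data'][0] != expected_ptr:
            return False  # gap, overlap, or separate allocation
        expected_ptr += c.nbytes
    return True

flat = np.zeros(8, dtype=np.float32)
views = [flat[i * 2:(i + 1) * 2] for i in range(4)]
print(outputs_are_flat_views(views))                 # True: flat layout
print(outputs_are_flat_views([views[0], views[2]]))  # False: gap in layout
```

When the check fails, the old behavior (flatten, gather, un-flatten with D2D copies) is the only correct option; when it succeeds, those copies are pure overhead.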

The situation for reduce-scatter is fundamentally the same, but reversed. Instead of doing D2D copies after the op to un-flatten the output tensors, reduce-scatter has to do D2D copies before the op to flatten the input tensors. The performance overhead is the same as for all-gather. This PR fixes both all-gather and reduce-scatter.

This PR should have no side effects. Everything should work exactly as before, but when the special case is detected, you will see a large bump in throughput.

Thanks to @alpha0422 for contributing the original code.

Command line:
python -m torch.distributed.launch --nnodes=1 --nproc_per_node=8 test_all_gather.py

test_all_gather.py

import time
import torch

def main(args):
    # Allocate one flat tensor and expose per-rank slices as views into it:
    # exactly the layout this PR detects and optimizes.
    a = torch.randn([args.group_size * args.size]).float().cuda()
    a_vec = [a[i * args.size:(i + 1) * args.size] for i in range(args.group_size)]
    if torch.distributed.get_rank() == 0:
        print("We used up %d MB of GPU memory." % (torch.cuda.memory_allocated() / (1024 * 1024)))
        print("Doing %d all-gathers of %d x %d floats" % (args.niter, args.group_size, args.size))
    torch.distributed.barrier()
    before = time.time()
    for _ in range(args.niter):
        torch.distributed.all_gather(a_vec, a_vec[torch.distributed.get_rank()])
    torch.cuda.synchronize()
    torch.distributed.barrier()
    after = time.time()
    print("rank %d :: a.norm() = %e, run_time was %.2f seconds"
          % (torch.distributed.get_rank(), a.norm(), after - before))

if __name__ == '__main__':
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
    torch.manual_seed(888)

    class Args:
        size = 1024 * 1024 * 128
        niter = 100
        group_size = torch.distributed.get_world_size()

    main(Args())

@dr-ci

dr-ci bot commented Feb 28, 2020

💊 Build failures summary and remediations

As of commit 1f674c8 (more details on the Dr. CI page):


  • 7/7 failures possibly* introduced in this PR
    • 2/7 non-CircleCI failure(s)

🕵️ 5 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (1/5)

Step: "Checkout pytorch/builder repo" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44
+ sleep 2 
+ git submodule update --init --recursive 
fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44 
Unable to checkout '815e209e4f30910dfdc47aa2ad041e3a46d61b44' in submodule path 'third_party/fbgemm' 
+ sleep 4 
+ git submodule update --init --recursive 
fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44 
Unable to checkout '815e209e4f30910dfdc47aa2ad041e3a46d61b44' in submodule path 'third_party/fbgemm' 
+ sleep 8 
+ git submodule update --init --recursive 
fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44 
Unable to checkout '815e209e4f30910dfdc47aa2ad041e3a46d61b44' in submodule path 'third_party/fbgemm' 

See CircleCI build caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build (2/5)

Step: "Build" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Apr 27 16:57:44 fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44
DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:373 
 
real	0m25.283s 
user	0m0.094s 
sys	0m0.033s 
Apr 27 16:57:29 ++ export BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build 
Apr 27 16:57:29 ++ BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build 
Apr 27 16:57:29 ++ git submodule sync 
Apr 27 16:57:29 ++ git submodule update -q --init --recursive 
Apr 27 16:57:44 fatal: reference is not a tree: 815e209e4f30910dfdc47aa2ad041e3a46d61b44 
Apr 27 16:58:18 Unable to checkout '815e209e4f30910dfdc47aa2ad041e3a46d61b44' in submodule path 'third_party/fbgemm' 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (3/5)

Step: "Build" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/dimensions.py 
Auto-merging .circleci/cimodel/data/dimensions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/caffe2_build_definitions.py 
Auto-merging .circleci/cimodel/data/caffe2_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_definitions.py 
Auto-merging .circleci/cimodel/data/binary_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/README.md 
Auto-merging .circleci/README.md 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_build (4/5)

Step: "Build" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Apr 27 16:59:43 caused by: Connection refused (os error 111)
Apr 27 16:59:43 +++ eval 'extract_trap_cmd ' 
Apr 27 16:59:43 ++++ extract_trap_cmd 
Apr 27 16:59:43 ++++ printf '%s\n' '' 
Apr 27 16:59:43 +++ printf '%s\n' cleanup 
Apr 27 16:59:43 ++ trap -- ' 
Apr 27 16:59:43 cleanup' EXIT 
Apr 27 16:59:43 ++ which sccache 
Apr 27 16:59:43 ++ sccache --stop-server 
Apr 27 16:59:43 Stopping sccache server... 
Apr 27 16:59:43 error: couldn't connect to server 
Apr 27 16:59:43 caused by: Connection refused (os error 111) 
Apr 27 16:59:43 ++ true 
Apr 27 16:59:43 ++ rm /var/lib/jenkins/sccache_error.log 
Apr 27 16:59:43 rm: cannot remove '/var/lib/jenkins/sccache_error.log': No such file or directory 
Apr 27 16:59:43 ++ true 
Apr 27 16:59:43 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log 
Apr 27 16:59:43 ++ SCCACHE_IDLE_TIMEOUT=1200 
Apr 27 16:59:43 ++ RUST_LOG=sccache::server=error 
Apr 27 16:59:43 ++ sccache --start-server 
Apr 27 16:59:43 Starting sccache server... 
Apr 27 16:59:43 ++ sccache --zero-stats 

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_build (5/5)

Step: "Build" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

C:\Jenkins\Miniconda3\Lib\site-packages\caffe2 : The system cannot find the path specified.
 
Add new data to archive: 0 files, 0 bytes 
 
 
Files read from disk: 0 
Archive size: 32 bytes (1 KiB) 
 
Scan WARNINGS for files and folders: 
 
C:\Jenkins\Miniconda3\Lib\site-packages\torch : The system cannot find the path specified. 
C:\Jenkins\Miniconda3\Lib\site-packages\caffe2 : The system cannot find the path specified. 
---------------- 
Scan WARNINGS: 2 
+ cleanup
+ retcode=1
+ set +x

Extra GitHub checks: 1 failed


ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI. This comment has been revised 57 times.

@mrshenli
Contributor

cc @agolynski Does this PR have implications for all_gather_coalesced and all_gather_base/single?

@agolynski
Contributor

cc @agolynski Does this PR have implications for all_gather_coalesced and all_gather_base/single?

This is a useful PR (at least in the current state of the code).

As is, this does not seem to play well with the coalesced implementations, since tensors are assumed to be of the same size, e.g.
t.storage_offset() != (tensor_lists[i][0].storage_offset() +
j * other[i].numel())

but I don't see a reason why this can't be extended to work with coalesced tensors as well.

The eventual direction is to have
all_gather_base, which will guarantee that the outputs are 'inplace' instead of discovering it in code as in this PR. However, if clients call into PG all_gather with custom allocated output tensors, this optimization is still useful.
Does it make sense?

std::vector<at::Tensor>& other,
size_t world_size) {
size_t world_size,
size_t rank) {
Contributor

to simplify reasoning in upstream code, can you introduce
bool* inplace or bool* already_flat instead and fill it in as you do below for 'inplace'?

Contributor Author

Do you mean introduce a new flag in python API so that "discovery" logic can be removed from ProcessGroupNCCL.cpp? In other words

torch.distributed.all_gather(output_tensors, input_tensor)

Will always flatten output tensors, while

torch.distributed.all_gather(output_tensors, input_tensor, already_flat=True)

will always do an inplace reduction?

Contributor

This can be an option as well, but I think the point of this change is to allow user code to run fast without necessarily changing it explicitly, so they'll get it 'for free'.

I meant having output bool here, e.g.
auto outputFlattened = flatten_for_scatter_gather(
outputTensors, inputTensors, size, rank, *output_is_flat
);

and then use it instead of

if (outputFlattened[i][j].storage().is_alias_of(
outputTensors[i][j].storage()) &&
outputTensors[i][j].storage_offset() ==
(outputTensors[i][0].storage_offset() +
outputTensors[i][j].numel() * j)) {
break;
}

Would it work?

Contributor Author

It will work, but the logic that determines if the list of tensors are views into an already flattened tensor will then have to go into the code that implements each NCCL op, so we'd end up duplicating it. Unless I put the reasoning code in a separate method.

bool inplace = true;
for (auto j = size_t{}; j < tensor_lists[i].size(); ++j) {
auto t = tensor_lists[i][j];
if (!t.storage().is_alias_of(other[i].storage()) ||
Contributor

I might be confused about this. Why should the outputs share storage with the inputs? Don't we just want all outputs to share storage, or is this a special case of allocating storage where the outputs are combined with the inputs (I didn't know this was allowed by the NCCL engine!)?
If that's the case, this saves the memory for a separate input allocation in addition to the output allocation!

However, for this specific case, maybe it's enough to check that all outputs are properly aligned?
Do you know what would happen if the outputs are aligned and share storage with the inputs, but the inputs aren't properly placed within the outputs? Would NCCL crash or produce garbage outputs?

Contributor Author

The NCCL library allows in-place operations; this is explicitly mentioned in the documentation. The inputs have to be properly aligned. I don't know what happens if they are not, other than that it will not work.

(screenshot of the NCCL documentation describing in-place operations)

Contributor

That makes sense. In-place will definitely save memory in addition to avoiding copies.

I was wondering if we can also save the copying in the case where we've allocated the output tensors separately from the inputs, but the outputs are still contiguous (that is, already flat). Something like:

input = torch.tensor(size)
....
outputs = torch.tensor(size * world_size * num_devices)
pg.allgather(output, input)

Do you think this use case is worth covering as well?

(Alternatively, we can just ask users to copy 'input' to the appropriate location inside 'output' and change its storage accordingly.)

What do you think?

Contributor Author

Both cases should work. For all-gather, we can avoid copying to flatten and un-flatten if the outputs are views into an already flattened tensor. In addition, if the input is also a view into the same flattened tensor and input is properly aligned, we can do an in-place operation. The code in my PR already does this, but I don't verify that the input is correctly aligned. I will add code to verify proper alignment.
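The alignment condition described above can be sketched like this, with NumPy raw data pointers standing in for the torch storage offsets the PR actually checks (names are illustrative): for an in-place op, the input must be exactly the rank-th slot of the flat output buffer.

```python
import numpy as np

def input_aligned_for_inplace(flat, inp, rank, chunk_elems):
    # In-place all-gather is only valid when the input *is* the rank-th
    # slot of the flat output buffer, i.e. it starts exactly
    # rank * chunk_elems elements into the allocation.
    base_ptr = flat.__array_interface__['data'][0]
    inp_ptr = inp.__array_interface__['data'][0]
    return inp_ptr == base_ptr + rank * chunk_elems * flat.itemsize

world_size, rank, n = 4, 1, 6
flat = np.zeros(world_size * n, dtype=np.float32)
outputs = [flat[i * n:(i + 1) * n] for i in range(world_size)]

print(input_aligned_for_inplace(flat, outputs[rank], rank, n))  # True
print(input_aligned_for_inplace(flat, outputs[0], rank, n))     # False: wrong slot
```

If the input aliases the flat buffer but sits in the wrong slot, an in-place collective would read or clobber another rank's slot, so the copy path must be used instead.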

@vincentqb vincentqb added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 28, 2020
@agolynski
Contributor

Would it be possible to add a test that exercises the new code path?

@thorjohnsen
Contributor Author

Absolutely! I can add a test script to verify correctness and demonstrate the perf benefit.

@thorjohnsen
Contributor Author

@mrshenli mentioned all_gather_coalesce as a potential issue. Is this something I need to address in this PR?

@agolynski
Contributor

I've discussed this with @mrshenli offline and he mentioned that this might become a breaking change if the following holds:

  • operation is performed async
  • tensors are positioned for in-place operation
  • client changes input tensor as operation is in-flight

previous behavior: we make a copy, and the correct copy is distributed
new behavior: undefined, as the tensor is modified while the op is in flight

the suggestion is to have this in API explicitly, e.g. add optional bool in_place = false to python, and similarly for C++. You'll still get your performance gains by setting in_place = true, but we can avoid breaking clients.

What do you think?

@thorjohnsen
Contributor Author

I think that is a sensible compromise. I will make the change.

@thorjohnsen thorjohnsen requested a review from apaszke as a code owner March 5, 2020 15:52
@thorjohnsen
Contributor Author

I have added inplace flag to torch.distributed.all_reduce and torch.distributed.all_gather. Please have a look and let me know if this is not what you had in mind. I am working on a couple of test scripts and will submit them when they are ready.

@agolynski
Contributor

I have added inplace flag to torch.distributed.all_reduce and torch.distributed.all_gather. Please have a look and let me know if this is not what you had in mind. I am working on a couple of test scripts and will submit them when they are ready.

Looks good, thank you!

@ddkalamk
Contributor

  • client changes input tensor as operation is in-flight

@agolynski Is this even allowed today from an API point of view? The MPI backend simply enqueues the operation to the MPI thread and returns. It may produce wrong results if the input tensors are modified before wait is called and completed. Am I missing something?

@ddkalamk
Contributor

the suggestion is to have this in API explicitly, e.g. add optional bool in_place = false to python, and similarly for C++. You'll still get your performance gains by setting in_place = true, but we can avoid breaking clients.

IMHO, it is better to do the in-place check internally, as the user can easily make a mistake here. And the check itself would be very simple and inexpensive in most cases.

@thorjohnsen
Contributor Author

thorjohnsen commented Mar 10, 2020

Hi @ddkalamk, the inplace flag simply signals that an in-place NCCL op is allowed. This only affects the NCCL backend. As you have pointed out, an in-place NCCL op can potentially fail if the user modifies the input tensors while the op is in flight, hence this is an opt-in feature, i.e. the inplace flag defaults to False. The code still checks internally whether the arguments can facilitate an in-place NCCL op, and will only schedule one if inplace == True and the arguments support it.
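The opt-in semantics described here boil down to a conjunction of the user's request and the layout discovery. A trivial sketch (the function and parameter names are illustrative, not the PR's actual identifiers):

```python
def use_nccl_inplace(inplace_requested, layout_supports_inplace):
    # In-place NCCL is scheduled only when the caller opted in via the
    # inplace flag AND the internal discovery logic found a suitable flat
    # layout; otherwise the old flatten-and-copy path runs, so existing
    # callers see unchanged behavior.
    return inplace_requested and layout_supports_inplace

assert use_nccl_inplace(True, True)        # opt-in + flat views -> in-place op
assert not use_nccl_inplace(False, True)   # default: old copying behavior
assert not use_nccl_inplace(True, False)   # opt-in but layout unfit -> copy path
```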

@ddkalamk
Contributor

inplace NCCL op can potentially fail if user modifies input tensors while op is in flight, hence this is an opt-in feature

Modifying inputs while an op is in flight is always disallowed, whether we are doing an in-place or out-of-place operation, since we don't get to know when the op has finished reading the input. Even the MPI backend would potentially fail if one modifies the input before the op is complete. So I'm not sure how this argument helps. On the other hand, what happens if the input and output buffers overlap and the user doesn't specify inplace=True? Or the other way around?

I am not familiar with NCCL, but from the API description it appears that if the input and output pointers are overlapping and correctly aligned, it would implicitly do an in-place operation. But at least in the MPI case, one needs to explicitly specify MPI_IN_PLACE for the send buffer when using an in-place operation.

@ddkalamk
Contributor

previous behavior: we make a copy and correct copy is being distributed

Sorry, I think I misunderstood this comment. @agolynski, do you mean that if the tensors are placed for an in-place operation (i.e. overlapping), we make a copy of the input to make it explicitly an out-of-place operation? And to disable this behavior we need an extra argument? Is this behavior specific to NCCL or does it apply to other backends as well?

@thorjohnsen thorjohnsen force-pushed the avoid_unnecessary_flattening_for_allgather branch 2 times, most recently from 11edb59 to b89fdc2 Compare March 14, 2020 03:29
@thorjohnsen
Contributor Author

How do I proceed with this PR? My last 3 commits only changed whitespace or wording inside comments; before that, all tests were passing except flake8. Flake8 is now passing, but some other tests are failing for reasons that seem unrelated to the code in this PR. For instance, caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build fails with the following error message:

Apr 27 16:58:18 Unable to checkout '815e209e4f30910dfdc47aa2ad041e3a46d61b44' in submodule path 'third_party/fbgemm'

815e209e4f30910dfdc47aa2ad041e3a46d61b44 is not one of my commits. Do I simply wait until these issues are resolved? I need to give a progress report to management. Grateful for any help!

@pritamdamania87
Contributor

@thorjohnsen The PR looks good to me; could you rebase and resubmit, and I'd be happy to merge it in. @agolynski Let me know if you have any further comments on this PR.

if (!tensor_lists[i][j].storage().is_alias_of(tensor_lists[i][0].storage()) ||
tensor_lists[i][j].storage_offset() != (tensor_lists[i][0].storage_offset() +
j * tensor_lists[i][0].numel())) {
no_copy = false;
Contributor

I wonder if it would be better to be explicit here and throw, or at least warn, if the user specified no_copy but we can't actually do it?

@osalpekar
Contributor

osalpekar commented Feb 26, 2021

Hi @thorjohnsen! Are you planning on continuing work on this PR? I think we should be good to go after a rebase, and there are some important perf issues this can solve.

Otherwise we're happy to take on this work from here!

@thorjohnsen
Contributor Author

Hi @osalpekar I'll try to get this rebased right now.

@osalpekar
Contributor

@thorjohnsen Thanks! Please let us know if you need any help.

@osalpekar
Contributor

@thorjohnsen Let us know if you can rebase by this Friday. Since merging this is a bit urgent, I'll have to take over the PR on Friday

@thorjohnsen
Contributor Author

@osalpekar If you have the time, I would be grateful if you can take over and bring this in. I am under a tight deadline and am unable to look at this PR until next week.

@slayton58
Contributor

@osalpekar I've taken a stab at rebasing here: https://github.com/slayton58/pytorch/tree/avoid_unnecessary_flattening but I've been unable to get tests to even run (they either abort or hang on my system, no output) -- it might help you some.

@zhaojuanmao
Contributor

Taking it over

@zhaojuanmao
Contributor

@thorjohnsen I cannot import your PR. Would you please sign the CLA? Or do you have time to rebase this PR? Thanks.

@zhaojuanmao
Contributor

Someone else is working on implementing all-gather-base, which will just return a single flattened tensor. I think that implementation is cleaner for the use cases mentioned here, and can replace the improvement in this PR.

thoughts? @agolynski

@facebook-github-bot
Contributor

Hi @thorjohnsen!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 28, 2022
@github-actions github-actions bot closed this Jun 27, 2022
pytorchmergebot pushed a commit that referenced this pull request Sep 27, 2022
### Description
- This PR renames `_all_gather_base` to `all_gather_into_tensor` so that its meaning is clearer.
- The `all_gather_into_tensor` API differs from the `all_gather` API in the output it accepts -- a single, large tensor instead of a list of tensors.
- This PR also adds a deprecation warning to `_all_gather_base`.

### Issue
`_all_gather_base` was implemented in #33924 to avoid unnecessary flattening. There was a previous effort (#82639) to merge `_all_gather_base` with the existing `all_gather` API by detecting the parameter type passed in for the output.

There are, however, two "blockers" that make the merge difficult:
(i) The merge leads to a backward compatibility break. We would need to change the parameter name `tensor_list` in `all_gather` to a general name `output` that can cover both a tensor and a tensor list.
(ii) Recently, the `all_gather` API added uneven tensor support, utilizing the tensor boundaries implied by the list. We are, however, not sure about adding such support to the `_all_gather_base` function, because that would require users to pass in additional tensor boundary information.

In view of the above, we decided to productize `_all_gather_base` as a separate function, but with a clearer name.

### Testing
Added tests:
- `test_all_gather_into_cat_tensor_cuda` -- output form as with `torch.cat`. For example:
```
        >>> tensor_in
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> tensor_out
        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
```
- `test_all_gather_into_stack_tensor_cuda` -- output form as with `torch.stack`. For example:
```
        >>> tensor_out2
        tensor([[1, 2],
                [3, 4]], device='cuda:0') # Rank 0
        tensor([[1, 2],
                [3, 4]], device='cuda:1') # Rank 1
```
The output form is determined by the shape of the output tensor passed by the user, no flag used.
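The shape-driven behavior described above can be mimicked in a single process. This is an illustrative NumPy sketch of the semantics only (`all_gather_into_tensor_local` is a made-up name), not the distributed implementation:

```python
import numpy as np

def all_gather_into_tensor_local(out, rank_inputs):
    # Single-process sketch: the gathered data is identical either way;
    # the caller-provided output shape selects cat-style (flat, length
    # world_size * n) vs stack-style (world_size x n) layout.
    gathered = np.stack(rank_inputs)        # shape (world_size, n)
    out[...] = gathered.reshape(out.shape)  # flat out -> cat form
    return out

ins = [np.array([1, 2]), np.array([3, 4])]
print(all_gather_into_tensor_local(np.empty(4, dtype=int), ins))       # cat form
print(all_gather_into_tensor_local(np.empty((2, 2), dtype=int), ins))  # stack form
```

Note how no extra flag is needed: the `(4,)` output reproduces the `torch.cat`-style result and the `(2, 2)` output the `torch.stack`-style result from the examples above.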

Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang
Pull Request resolved: #85686
Approved by: https://github.com/rohan-varma, https://github.com/crcrpar
drisspg pushed a commit to drisspg/pytorch that referenced this pull request Sep 29, 2022
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022