
[torch distributed] Implementing all_gather_base (#56304) #56315

Closed
liangluofb wants to merge 1 commit into pytorch:master from liangluofb:export-D27488999

@liangluofb
Contributor

Summary:
Pull Request resolved: #56304

This diff implements the all_gather_base in pytorch distributed.

Test Plan: dist.all_gather_base(output, input)...

Differential Revision: D27488999
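
As a rough illustration of what a flat-buffer all-gather does, the pure-Python sketch below simulates it without any process group. It assumes, which the summary above does not state, that the output buffer is world_size times the input length and holds every rank's input in rank order. The function name simulate_all_gather_base and the use of plain lists in place of tensors are hypothetical.

```python
def simulate_all_gather_base(inputs_per_rank):
    """Simulate a flat all-gather: every rank ends up with one contiguous
    buffer holding all ranks' inputs, concatenated in rank order.

    inputs_per_rank[r] stands in for the input tensor on rank r
    (hypothetical stand-in: plain Python lists instead of tensors).
    """
    if not inputs_per_rank:
        raise ValueError("need at least one rank")
    chunk = len(inputs_per_rank[0])
    if any(len(x) != chunk for x in inputs_per_rank):
        raise ValueError("all ranks must contribute inputs of equal length")

    world_size = len(inputs_per_rank)
    outputs_per_rank = []
    for _ in range(world_size):
        # Pre-allocated flat output, analogous to a single output tensor.
        output = [None] * (world_size * chunk)
        for src, data in enumerate(inputs_per_rank):
            # Rank src's contribution lands in slot src of the flat buffer.
            output[src * chunk:(src + 1) * chunk] = data
        outputs_per_rank.append(output)
    return outputs_per_rank
```

Under these assumptions every simulated rank receives the same flat buffer, in contrast to a list-based all-gather that fills one output object per rank.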

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 17, 2021

💊 CI failures summary and remediations

As of commit 98c7966 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


@facebook-github-bot facebook-github-bot added the oncall: distributed and fb-exported labels Apr 17, 2021
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D27488999

@codecov

codecov bot commented Apr 17, 2021

Codecov Report

Merging #56315 (59f4e8f) into master (ffdecc1) will decrease coverage by 0.00%.
The diff coverage is 21.05%.

❗ Current head 59f4e8f differs from pull request most recent head 19a8630. Consider uploading reports for the commit 19a8630 to get more accurate results

@@            Coverage Diff             @@
##           master   #56315      +/-   ##
==========================================
- Coverage   77.02%   77.02%   -0.01%     
==========================================
  Files        1924     1923       -1     
  Lines      190590   190521      -69     
==========================================
- Hits       146803   146748      -55     
+ Misses      43787    43773      -14     

@zhaojuanmao zhaojuanmao requested a review from agolynski April 19, 2021 18:04
Contributor

@zhaojuanmao zhaojuanmao left a comment

@liangluofb thanks for your contribution! The code looks good overall; I left some minor comments.

Also, we've discussed this internally, and we would like to make 'all_gather_base' a private API named '_all_gather_base' in both Python and C++, since @agolynski plans a c10d API cleanup that will merge all_gather and all_gather_base. Until that cleanup is done, we would prefer not to introduce a new set of APIs.

So would you please help change 'all_gather_base' to '_all_gather_base' in both Python and C++? Thanks.

Contributor

@agolynski agolynski left a comment

Mostly LG, just some minor comments. Thank you!

Summary:
Pull Request resolved: pytorch#56315

This diff implements the all_gather_base in pytorch distributed.

Test Plan: dist.all_gather_base(output, input)...

Reviewed By: agolynski, amylittleyang

Differential Revision: D27488999

fbshipit-source-id: 4b286563301324f9212c0e31f8540712453441e2
@facebook-github-bot
Contributor

This pull request has been merged in c370957.

krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Pull Request resolved: pytorch#56315

This diff implements the all_gather_base in pytorch distributed.

Test Plan: dist.all_gather_base(output, input)...

Reviewed By: agolynski, amylittleyang

Differential Revision: D27488999

fbshipit-source-id: 937ec8bddf9527fa4d114f984d1d0f6a5b8c3936
@zarzen

zarzen commented Jul 16, 2021

Hi,
Would you mind changing the function signature of _allgather_base to accept two lists of tensors, so that multiple tensor allgathers can be launched in a single ncclGroup? The current collective interface requires std::vector<at::Tensor> as inputs anyway.
I can submit a PR if you agree to the change.

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::_allgather_base(
    std::vector<at::Tensor>& outputs,
    std::vector<at::Tensor>& inputs,
    const AllgatherOptions& /* unused */)

At the Python front end we can do a type check: if the input is a single tensor, wrap it in a list.

def allgather_base(self, outputs, inputs, ... ):
    if not isinstance(outputs, list):
        # A single tensor was passed: wrap both arguments in lists.
        outputs = [outputs]
        inputs = [inputs]
    ...
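
The wrapping step proposed above could look like the following self-contained sketch. normalize_to_lists is a hypothetical helper, not part of PyTorch, and the length check is an added assumption about what a list-based interface would want to validate; any object can stand in for a tensor here.

```python
def normalize_to_lists(outputs, inputs):
    """Return (outputs, inputs) as two lists of equal length.

    Hypothetical helper: a bare (non-list, non-tuple) argument is treated
    as a single tensor and wrapped in a one-element list.
    """
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    outputs, inputs = list(outputs), list(inputs)
    if len(outputs) != len(inputs):
        raise ValueError(
            f"expected one output per input, got {len(outputs)} and {len(inputs)}"
        )
    return outputs, inputs
```

Checking each argument separately avoids the case where only one of the two is already a list and the other would otherwise be wrapped or left bare by mistake.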

Labels

cla signed, fb-exported, Merged, oncall: distributed

5 participants