
Conversation

@yifuwang
Collaborator

@yifuwang yifuwang commented Oct 8, 2024

Stack from ghstack (oldest at bottom):

This Stack

Implement the custom all-reduce algorithms available in IntraNodeComm as symm_mem ops, and replace the existing IntraNodeComm kernels with them.

This PR

  • Replaces one-shot all-reduce with symm_mem::one_shot_all_reduce_out
  • Replaces two-shot all-reduce with symm_mem::two_shot_all_reduce_
  • Removes HCM all-reduce (at least for now). Because of its accumulation order, we can't guarantee numerical consistency across ranks.
  • Removes the IntraNodeComm Python binding (its original purpose is superseded by SymmetricMemory).
  • Removes methods that existed only to support the Python binding.
  • Replaces the NVLink detection logic with DMAConnectivityDetector.
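
For intuition, the two retained strategies can be sketched in plain Python. This is a single-process simulation: the rank loops stand in for what the CUDA kernels do across GPUs over symmetric memory, and all names here are illustrative, not the actual symm_mem kernel code.

```python
# Single-process simulation of one-shot vs. two-shot all-reduce.
# "symm_buffers[r]" stands in for the symmetric-memory buffer rank r
# exposes to its peers.

WORLD_SIZE = 4

def one_shot_all_reduce(symm_buffers):
    # Every rank pulls all peers' full buffers and reduces locally:
    # one communication "shot", O(world_size * n) reads per rank.
    n = len(symm_buffers[0])
    return [
        [sum(symm_buffers[r][i] for r in range(WORLD_SIZE)) for i in range(n)]
        for _rank in range(WORLD_SIZE)
    ]

def two_shot_all_reduce(symm_buffers):
    # Shot 1 (reduce-scatter): rank r reduces only its own chunk.
    # Shot 2 (all-gather): every rank collects the reduced chunks.
    n = len(symm_buffers[0])
    chunk = n // WORLD_SIZE
    reduced = {}
    for r in range(WORLD_SIZE):
        lo, hi = r * chunk, (r + 1) * chunk
        reduced[r] = [sum(symm_buffers[p][i] for p in range(WORLD_SIZE))
                      for i in range(lo, hi)]
    gathered = [v for r in range(WORLD_SIZE) for v in reduced[r]]
    return [list(gathered) for _rank in range(WORLD_SIZE)]

buffers = [[float(r + i) for i in range(8)] for r in range(WORLD_SIZE)]
assert one_shot_all_reduce(buffers) == two_shot_all_reduce(buffers)
```

The trade-off the two strategies represent: one-shot minimizes latency (one synchronization) at the cost of redundant reads, while two-shot moves less data per rank, which tends to win at larger message sizes.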

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
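
The numerical-consistency concern behind dropping HCM all-reduce comes from floating-point addition being non-associative: if different ranks accumulate the same contributions in different orders, their outputs can disagree bit-for-bit. A minimal demonstration (the values are chosen only to make the rounding visible):

```python
# Floating-point addition is not associative: summing the same values
# in a different order can round differently.
vals = [1e16, 0.1, -1e16, 0.1]

left_to_right = sum(vals)       # 0.1 is absorbed into 1e16, then recovered
reordered = sum(sorted(vals))   # both 0.1s are absorbed before cancelling

assert left_to_right != reordered
```

An all-reduce whose per-rank accumulation order differs can therefore leave ranks with slightly different results, which is exactly what the retained one-shot and two-shot paths avoid by fixing the reduction order.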

@pytorch-bot

pytorch-bot bot commented Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137475

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3a627a7 with merge base 9b2e453:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Oct 8, 2024
yifuwang pushed a commit that referenced this pull request Oct 8, 2024
…ing symm_mem ops"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
yifuwang pushed a commit that referenced this pull request Oct 8, 2024
@yifuwang yifuwang requested review from Chillee and weifengpy October 8, 2024 19:10
@yifuwang yifuwang marked this pull request as ready for review October 8, 2024 19:10
Collaborator

@Chillee Chillee left a comment


Nice

pytorchmergebot pushed a commit that referenced this pull request Oct 9, 2024
…ating bfloat16 with multimem.ld_reduce (#137529)

This provides better accuracy without additional cost.

Also added documentation to `multimem_one_shot_all_reduce` to note the numerical caveats.
Pull Request resolved: #137529
Approved by: https://github.com/Chillee
ghstack dependencies: #137471, #137472, #137473, #137474, #137475
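
The accuracy claim can be illustrated in plain Python by simulating bfloat16 rounding: accumulating a running sum in bf16 rounds after every addition, whereas a `multimem.ld_reduce`-style path reduces in higher precision and rounds to bf16 once. The `to_bf16` helper and the input data below are illustrative, not the actual hardware semantics.

```python
import struct

def to_bf16(x: float) -> float:
    # Round a float to bfloat16 precision (7 explicit mantissa bits) by
    # rounding the float32 encoding to its upper 16 bits.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x8000) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# One large contribution plus several small ones, as can happen with
# mixed-magnitude gradients across ranks.
vals = [1024.0] + [1.0] * 7
exact = sum(vals)  # 1031.0

# Naive path: keep the running sum in bf16 after every addition;
# the small contributions are absorbed one by one.
acc_naive = 0.0
for v in vals:
    acc_naive = to_bf16(acc_naive + v)

# ld_reduce-style path: accumulate in full precision, round once.
acc_once = to_bf16(sum(vals))

assert abs(acc_once - exact) < abs(acc_naive - exact)
```

The single final rounding bounds the error to one bf16 ulp of the true sum, while per-step rounding can lose every small addend, which is why the change improves accuracy "without additional cost" on hardware that reduces in the switch.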
pytorchmergebot pushed a commit that referenced this pull request Oct 9, 2024
- Previously, detection would fail if run before the user called APIs such as `torch.cuda.set_device()`, because the detection logic requires NVML initialization. This PR adds explicit NVML initialization (which is idempotent).
- Previously, any NVML issue in the detection logic resulted in a fatal error. Now we issue an informative warning and return a topology that assumes no NVLink connectivity.
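
A hypothetical sketch of the two fixes described above: the detector initializes NVML itself (NVML's init call is idempotent, so calling it again is safe), and any NVML failure degrades to a "no NVLink" topology with a warning instead of a fatal error. The function names and the topology shape here are illustrative stand-ins, not PyTorch's actual C++ implementation.

```python
import warnings

class NvmlError(RuntimeError):
    pass

def nvml_init():
    # Stand-in for NVML initialization; the real call succeeds even if
    # NVML was already initialized, which makes the explicit call safe.
    pass

def query_nvlink_matrix(world_size):
    # Stand-in for the per-device NVML NVLink queries; here it fails
    # to exercise the fallback path.
    raise NvmlError("driver/library mismatch")

def detect_connectivity(world_size):
    try:
        nvml_init()
        return query_nvlink_matrix(world_size)
    except NvmlError as e:
        warnings.warn(f"NVLink detection via NVML failed ({e}); "
                      "assuming no NVLink connectivity.")
        # Fall back: report that no device pair is connected.
        return [[False] * world_size for _ in range(world_size)]

topo = detect_connectivity(4)
```

Returning a valid (if pessimistic) topology lets callers proceed on systems without working NVML instead of crashing during import-time detection.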

Pull Request resolved: #137530
Approved by: https://github.com/Chillee
ghstack dependencies: #137471, #137472, #137473, #137474, #137475, #137529
jackzhxng pushed a commit that referenced this pull request Oct 16, 2024
…m ops (#137475)

## This Stack

Implement the custom all-reduce algorithms available in `IntraNodeComm` as `symm_mem` ops, and replace the existing `IntraNodeComm` kernels with them.

## This PR
- Replaces one-shot all-reduce with `symm_mem::one_shot_all_reduce_out`
- Replaces two-shot all-reduce with `symm_mem::two_shot_all_reduce_`
- Removes HCM all-reduce (at least for now). Because of its accumulation order, we can't guarantee numerical consistency across ranks.
- Removes the `IntraNodeComm` Python binding (its original purpose is superseded by `SymmetricMemory`).
- Removes methods that existed only to support the Python binding.
- Replaces the NVLink detection logic with `DMAConnectivityDetector`.

Pull Request resolved: #137475
Approved by: https://github.com/Chillee
ghstack dependencies: #137471, #137472, #137473, #137474