
Conversation

@daisyden (Collaborator) commented Jul 17, 2025

For #114850, we will port distributed tests to Intel GPU.
We enable Intel GPU with the following methods, keeping the original code style as much as possible (see the sketch after this list):

  • instantiate_device_type_tests()
  • use torch.accelerator.current_accelerator() to determine the accelerator backend
  • enable XPU for some test paths
  • unify some common code under torch/testing/_internal for multiple backends, for example:
    • requires_nccl_version
    • _dynamo_dist_per_rank_init
    • DynamoDistributedSingleProcTestCase
    • DistTestCases
    • FSDPTestMultiThread
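
A minimal sketch of the pattern, assuming a recent PyTorch build that provides torch.accelerator and the allow_xpu flag on instantiate_device_type_tests(); the test class below is illustrative, not code from this PR:

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestPortableOps(TestCase):
    def test_add(self, device):
        # `device` is injected per backend: "cpu", "cuda:0", "xpu:0", ...
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

# allow_xpu=True (assumed available) instantiates the generic class for
# Intel GPUs as well, instead of CUDA/CPU only.
instantiate_device_type_tests(TestPortableOps, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()
```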

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @gujinghui @EikanWang @fengyuan14 @guangyey

@pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158533

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 32111ce with merge base 9d37c96:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed and release notes: distributed (fsdp) labels Jul 17, 2025
@daisyden changed the title from "port 3 distributed test to Intel GPU and unified some common functions" to "[WIP] port 3 distributed test to Intel GPU and unified some common functions" Jul 17, 2025
@guangyey added the ciflow/xpu label Jul 17, 2025
@pytorch-bot bot removed the ciflow/xpu label Jul 17, 2025
@etaf added the ciflow/xpu label Jul 17, 2025
@pytorch-bot bot removed the ciflow/xpu label Jul 17, 2025
@daisyden (Collaborator, Author) commented:

@pytorchbot label "module: xpu"
@pytorchbot label "triaged"

@pytorch-bot bot added the module: xpu label Jul 17, 2025
@guangyey added the ciflow/xpu label Jul 17, 2025
@pytorch-bot bot removed the ciflow/xpu label Jul 18, 2025
@jingxu10 added the ciflow/xpu label Jul 18, 2025
@pytorch-bot bot removed the ciflow/xpu label Jul 23, 2025
Comment on lines 1479 to 1481:

```python
backend = c10d.get_default_backend_for_device(
    torch.accelerator.current_accelerator().type
)
```

A collaborator commented:

Suggested change:

```diff
-backend = c10d.get_default_backend_for_device(
-    torch.accelerator.current_accelerator().type
-)
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+backend = c10d.get_default_backend_for_device(device_type)
```

Otherwise, torch.accelerator.current_accelerator() will return None if no accelerator is detected.
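
A runnable sketch of the guarded pattern, assuming a PyTorch build where torch.accelerator and c10d.get_default_backend_for_device are available:

```python
import torch
import torch.distributed.distributed_c10d as c10d

# current_accelerator() returns None on machines without an accelerator,
# so guard it and fall back to CPU.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
backend = c10d.get_default_backend_for_device(device_type)
print(device_type, backend)  # e.g. "xpu" "xccl", or "cpu" "gloo"
```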

@daisyden (Collaborator, Author) replied:

Updated, thanks!

@guangyey previously approved these changes Jul 23, 2025

@guangyey (Collaborator) left a comment:
One nit, otherwise LGTM.

@guangyey changed the title from "[WIP] port 3 distributed test to Intel GPU and unified some common functions" to "port 3 distributed test to Intel GPU and unified some common functions" Jul 23, 2025
@guangyey moved this to Review Required in PyTorch Intel Jul 23, 2025
@guangyey dismissed their stale review Jul 23, 2025 09:23

Please help check the CI failure.

@guangyey (Collaborator) left a comment:

Thanks for the update!

@guangyey added the ciflow/xpu label Jul 24, 2025
@guangyey (Collaborator) commented:

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented:

Successfully rebased daisyden/dist_upstream_s1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout daisyden/dist_upstream_s1 && git pull --rebase)

@pytorchmergebot force-pushed the daisyden/dist_upstream_s1 branch from 323ada3 to 7d7531e on August 12, 2025 02:00
@guangyey (Collaborator) commented:

@daisyden Please fix the lint error.

@guangyey (Collaborator) commented:

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label Aug 13, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@github-project-automation bot moved this from Review Required to Done in PyTorch Intel Aug 13, 2025
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
port 3 distributed test to Intel GPU and unified some common functions (#158533)

For #114850, we will port distributed tests to Intel GPU.
We enable Intel GPU with the following methods, keeping the original code style as much as possible:

- instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- enable XPU for some test paths
- unify some common code under torch/testing/_internal for multiple backends, for example:
  - requires_nccl_version
  - _dynamo_dist_per_rank_init
  - DynamoDistributedSingleProcTestCase
  - DistTestCases
  - FSDPTestMultiThread

Pull Request resolved: #158533
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Yu, Guangye <[email protected]>
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 6, 2025
For #114850, we will port distributed tests to Intel GPU. This PR is created based on PRs #158533 and #159473 and works on some test files under test/distributed/fsdp. We enable Intel GPU with the following methods, keeping the original code style as much as possible:

1. add allow_xpu=True in instantiate_device_type_tests() if needed
2. use torch.accelerator.current_accelerator() to determine the accelerator backend (a sketch follows this block)
3. enable XPU for some test paths
Pull Request resolved: #161601
Approved by: https://github.com/guangyey, https://github.com/d4l3k
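
A hedged sketch of step 2, replacing a hard-coded "cuda" device with the detected accelerator so the same test body also runs on XPU; the helper below is illustrative, not code from this PR:

```python
import torch

# Pick whichever accelerator is present ("cuda", "xpu", ...), else CPU.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"

def make_model_and_input():
    # Before the port this would have hard-coded device="cuda".
    model = torch.nn.Linear(8, 8).to(device_type)
    inp = torch.randn(2, 8, device=device_type)
    return model, inp

model, inp = make_model_and_input()
print(model(inp).shape)  # torch.Size([2, 8])
```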
daisyden added a commit to daisyden/pytorch that referenced this pull request Sep 8, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025

Labels

ciflow/trunk, ciflow/xpu, Merged, module: xpu, oncall: distributed, open source, release notes: distributed (fsdp)

Projects

Status: Done


7 participants